Skip to content

Commit 4f75dd7

Browse files
committed
Use powf(1/3) instead of cbrt()
Like 50% faster, not sure why All tests still pass
1 parent a02481a commit 4f75dd7

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

src/lib.rs

+21-10
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ pub trait DType:
8080
fn spowf(self, rhs: Self) -> Self;
8181
fn rem_euclid(self, rhs: Self) -> Self;
8282

83-
fn sqrt(self) -> Self;
84-
fn cbrt(self) -> Self;
85-
8683
fn abs(self) -> Self;
8784
fn trunc(self) -> Self;
8885
fn max(self, other: Self) -> Self;
@@ -94,6 +91,19 @@ pub trait DType:
9491
fn to_radians(self) -> Self;
9592
fn atan2(self, rhs: Self) -> Self;
9693

94+
fn sqrt(self) -> Self {
95+
self.powf((1.0 / 2.0).to_dt())
96+
}
97+
fn cbrt(self) -> Self {
98+
self.powf((1.0 / 3.0).to_dt())
99+
}
100+
fn ssqrt(self) -> Self {
101+
self.spowf((1.0 / 2.0).to_dt())
102+
}
103+
fn scbrt(self) -> Self {
104+
self.spowf((1.0 / 3.0).to_dt())
105+
}
106+
97107
fn _fma(self, mul: Self, add: Self) -> Self;
98108
/// Fused multiply-add if "fma" is enabled in rustc
99109
fn fma(self, mul: Self, add: Self) -> Self {
@@ -121,12 +131,6 @@ macro_rules! impl_float {
121131
fn rem_euclid(self, rhs: Self) -> Self {
122132
self.rem_euclid(rhs)
123133
}
124-
fn sqrt(self) -> Self {
125-
self.sqrt()
126-
}
127-
fn cbrt(self) -> Self {
128-
self.cbrt()
129-
}
130134
fn abs(self) -> Self {
131135
self.abs()
132136
}
@@ -154,6 +158,13 @@ macro_rules! impl_float {
154158
fn atan2(self, rhs: Self) -> Self {
155159
self.atan2(rhs)
156160
}
161+
fn sqrt(self) -> Self {
162+
self.sqrt()
163+
}
164+
// 50% slower than powf/spowf?
165+
//fn cbrt(self) -> Self {
166+
// self.cbrt()
167+
//}
157168
fn _fma(self, mul: Self, add: Self) -> Self {
158169
self.mul_add(mul, add)
159170
}
@@ -1133,7 +1144,7 @@ where
11331144
Channels<N>: ValidChannels,
11341145
{
11351146
let mut lms = matmul3t([pixel[0], pixel[1], pixel[2]], OKLAB_M1);
1136-
lms.iter_mut().for_each(|c| *c = c.cbrt());
1147+
lms.iter_mut().for_each(|c| *c = c.scbrt());
11371148
[pixel[0], pixel[1], pixel[2]] = matmul3t(lms, OKLAB_M2);
11381149
}
11391150

0 commit comments

Comments
 (0)