Skip to content

Commit aaa2742

Browse files
committed
Standardize matrix functions
Only one *correct* matrix multiply function now. All matrices except Oklab are fed through a compile-time transpose so the visually written columns match up with m[vec][elem].
1 parent cd8395b commit aaa2742

File tree

1 file changed

+56
-52
lines changed

1 file changed

+56
-52
lines changed

src/lib.rs

+56-52
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,33 @@ const JZAZBZ_P: f32 = 1.7 * PQEOTF_M2;
245245
// ### CONSTS ### }}}
246246

247247
// ### MATRICES ### {{{
248+
249+
/// It's easier to write matrices visually, then transpose them so they can be indexed per vector
250+
/// [X1, X2] -> [X1, Y1]
251+
/// [Y1, Y2] [X2, Y2]
252+
const fn t(m: [[f32; 3]; 3]) -> [[f32; 3]; 3] {
253+
[
254+
[m[0][0], m[1][0], m[2][0]],
255+
[m[0][1], m[1][1], m[2][1]],
256+
[m[0][2], m[1][2], m[2][2]],
257+
]
258+
}
259+
260+
/// Matrix Multiply
261+
fn mm<T: DType>(m: [[f32; 3]; 3], p: [T; 3]) -> [T; 3] {
262+
[
263+
p[0].fma(m[0][0].to_dt(), p[1].fma(m[1][0].to_dt(), p[2] * m[2][0].to_dt())),
264+
p[0].fma(m[0][1].to_dt(), p[1].fma(m[1][1].to_dt(), p[2] * m[2][1].to_dt())),
265+
p[0].fma(m[0][2].to_dt(), p[1].fma(m[1][2].to_dt(), p[2] * m[2][2].to_dt())),
266+
]
267+
}
268+
248269
// CIE XYZ
249-
const XYZ65_MAT: [[f32; 3]; 3] = [
270+
const XYZ65_MAT: [[f32; 3]; 3] = t([
250271
[0.4124, 0.3576, 0.1805],
251272
[0.2126, 0.7152, 0.0722],
252273
[0.0193, 0.1192, 0.9505],
253-
];
274+
]);
254275

255276
// Original commonly used inverted array
256277
// const XYZ65_MAT_INV: [[f32; 3]; 3] = [
@@ -260,13 +281,14 @@ const XYZ65_MAT: [[f32; 3]; 3] = [
260281
// ];
261282

262283
// Higher precision invert using numpy. Helps with back conversions
263-
const XYZ65_MAT_INV: [[f32; 3]; 3] = [
284+
const XYZ65_MAT_INV: [[f32; 3]; 3] = t([
264285
[3.2406254773, -1.5372079722, -0.4986285987],
265286
[-0.9689307147, 1.8757560609, 0.0415175238],
266287
[0.0557101204, -0.2040210506, 1.0569959423],
267-
];
288+
]);
268289

269290
// OKLAB
291+
// The Oklab matrices appear to be provided already transposed for the code in the blog post, so they are not passed through t()
270292
const OKLAB_M1: [[f32; 3]; 3] = [
271293
[0.8189330101, 0.0329845436, 0.0482003018],
272294
[0.3618667424, 0.9293118715, 0.2643662691],
@@ -289,68 +311,50 @@ const OKLAB_M2_INV: [[f32; 3]; 3] = [
289311
];
290312

291313
// JzAzBz
292-
const JZAZBZ_M1: [[f32; 3]; 3] = [
314+
const JZAZBZ_M1: [[f32; 3]; 3] = t([
293315
[0.41478972, 0.579999, 0.0146480],
294316
[-0.2015100, 1.120649, 0.0531008],
295317
[-0.0166008, 0.264800, 0.6684799],
296-
];
297-
const JZAZBZ_M2: [[f32; 3]; 3] = [
318+
]);
319+
const JZAZBZ_M2: [[f32; 3]; 3] = t([
298320
[0.500000, 0.500000, 0.000000],
299321
[3.524000, -4.066708, 0.542708],
300322
[0.199076, 1.096799, -1.295875],
301-
];
323+
]);
302324

303-
const JZAZBZ_M1_INV: [[f32; 3]; 3] = [
325+
const JZAZBZ_M1_INV: [[f32; 3]; 3] = t([
304326
[1.9242264358, -1.0047923126, 0.037651404],
305327
[0.3503167621, 0.7264811939, -0.0653844229],
306328
[-0.090982811, -0.3127282905, 1.5227665613],
307-
];
308-
const JZAZBZ_M2_INV: [[f32; 3]; 3] = [
329+
]);
330+
const JZAZBZ_M2_INV: [[f32; 3]; 3] = t([
309331
[1., 0.1386050433, 0.0580473162],
310332
[1., -0.1386050433, -0.0580473162],
311333
[1., -0.096019242, -0.8118918961],
312-
];
334+
]);
313335

314336
// ICtCp
315-
const ICTCP_M1: [[f32; 3]; 3] = [
337+
const ICTCP_M1: [[f32; 3]; 3] = t([
316338
[1688. / 4096., 2146. / 4096., 262. / 4096.],
317339
[683. / 4096., 2951. / 4096., 462. / 4096.],
318340
[99. / 4096., 309. / 4096., 3688. / 4096.],
319-
];
320-
const ICTCP_M2: [[f32; 3]; 3] = [
341+
]);
342+
const ICTCP_M2: [[f32; 3]; 3] = t([
321343
[2048. / 4096., 2048. / 4096., 0. / 4096.],
322344
[6610. / 4096., -13613. / 4096., 7003. / 4096.],
323345
[17933. / 4096., -17390. / 4096., -543. / 4096.],
324-
];
346+
]);
325347

326-
const ICTCP_M1_INV: [[f32; 3]; 3] = [
348+
const ICTCP_M1_INV: [[f32; 3]; 3] = t([
327349
[3.4366066943, -2.5064521187, 0.0698454243],
328350
[-0.7913295556, 1.9836004518, -0.1922708962],
329351
[-0.0259498997, -0.0989137147, 1.1248636144],
330-
];
331-
const ICTCP_M2_INV: [[f32; 3]; 3] = [
352+
]);
353+
const ICTCP_M2_INV: [[f32; 3]; 3] = t([
332354
[1., 0.008609037, 0.111029625],
333355
[1., -0.008609037, -0.111029625],
334356
[1., 0.5600313357, -0.320627175],
335-
];
336-
337-
/// 3 * 3x3 Matrix multiply with vector transposed, ie pixel @ matrix
338-
fn matmul3t<T: DType>(p: [T; 3], m: [[f32; 3]; 3]) -> [T; 3] {
339-
[
340-
p[0].fma(m[0][0].to_dt(), p[1].fma(m[1][0].to_dt(), p[2] * m[2][0].to_dt())),
341-
p[0].fma(m[0][1].to_dt(), p[1].fma(m[1][1].to_dt(), p[2] * m[2][1].to_dt())),
342-
p[0].fma(m[0][2].to_dt(), p[1].fma(m[1][2].to_dt(), p[2] * m[2][2].to_dt())),
343-
]
344-
}
345-
346-
/// Transposed 3 * 3x3 matrix multiply, ie matrix @ pixel
347-
fn matmul3<T: DType>(m: [[f32; 3]; 3], p: [T; 3]) -> [T; 3] {
348-
[
349-
p[0].fma(m[0][0].to_dt(), p[1].fma(m[0][1].to_dt(), p[2] * m[0][2].to_dt())),
350-
p[0].fma(m[1][0].to_dt(), p[1].fma(m[1][1].to_dt(), p[2] * m[1][2].to_dt())),
351-
p[0].fma(m[2][0].to_dt(), p[1].fma(m[2][1].to_dt(), p[2] * m[2][2].to_dt())),
352-
]
353-
}
357+
]);
354358
// ### MATRICES ### }}}
355359

356360
// ### TRANSFER FUNCTIONS ### {{{
@@ -1112,7 +1116,7 @@ pub fn lrgb_to_xyz<T: DType, const N: usize>(pixel: &mut [T; N])
11121116
where
11131117
Channels<N>: ValidChannels,
11141118
{
1115-
[pixel[0], pixel[1], pixel[2]] = matmul3(XYZ65_MAT, [pixel[0], pixel[1], pixel[2]])
1119+
[pixel[0], pixel[1], pixel[2]] = mm(XYZ65_MAT, [pixel[0], pixel[1], pixel[2]])
11161120
}
11171121

11181122
/// Convert from CIE XYZ to CIE LAB.
@@ -1147,9 +1151,9 @@ pub fn xyz_to_oklab<T: DType, const N: usize>(pixel: &mut [T; N])
11471151
where
11481152
Channels<N>: ValidChannels,
11491153
{
1150-
let mut lms = matmul3t([pixel[0], pixel[1], pixel[2]], OKLAB_M1);
1154+
let mut lms = mm(OKLAB_M1, [pixel[0], pixel[1], pixel[2]]);
11511155
lms.iter_mut().for_each(|c| *c = c.scbrt());
1152-
[pixel[0], pixel[1], pixel[2]] = matmul3t(lms, OKLAB_M2);
1156+
[pixel[0], pixel[1], pixel[2]] = mm(OKLAB_M2, lms);
11531157
}
11541158

11551159
/// Convert CIE XYZ to JzAzBz
@@ -1159,7 +1163,7 @@ pub fn xyz_to_jzazbz<T: DType, const N: usize>(pixel: &mut [T; N])
11591163
where
11601164
Channels<N>: ValidChannels,
11611165
{
1162-
let mut lms = matmul3(
1166+
let mut lms = mm(
11631167
JZAZBZ_M1,
11641168
[
11651169
pixel[0].fma(JZAZBZ_B.to_dt(), T::ff32(-JZAZBZ_B + 1.0) * pixel[2]),
@@ -1170,7 +1174,7 @@ where
11701174

11711175
lms.iter_mut().for_each(|e| *e = pqz_oetf(*e));
11721176

1173-
let lab = matmul3(JZAZBZ_M2, lms);
1177+
let lab = mm(JZAZBZ_M2, lms);
11741178

11751179
pixel[0] = (T::ff32(1.0 + JZAZBZ_D) * lab[0]) / lab[0].fma(JZAZBZ_D.to_dt(), 1.0.to_dt()) - JZAZBZ_D0.to_dt();
11761180
pixel[1] = lab[1];
@@ -1200,10 +1204,10 @@ where
12001204
// };
12011205
// pixel.iter_mut().for_each(|c| bt2020(c));
12021206

1203-
let mut lms = matmul3(ICTCP_M1, [pixel[0], pixel[1], pixel[2]]);
1207+
let mut lms = mm(ICTCP_M1, [pixel[0], pixel[1], pixel[2]]);
12041208
// lms prime
12051209
lms.iter_mut().for_each(|c| *c = pq_oetf(*c));
1206-
[pixel[0], pixel[1], pixel[2]] = matmul3(ICTCP_M2, lms);
1210+
[pixel[0], pixel[1], pixel[2]] = mm(ICTCP_M2, lms);
12071211
}
12081212

12091213
/// Converts an LAB based space to a cylindrical representation.
@@ -1334,7 +1338,7 @@ pub fn xyz_to_lrgb<T: DType, const N: usize>(pixel: &mut [T; N])
13341338
where
13351339
Channels<N>: ValidChannels,
13361340
{
1337-
[pixel[0], pixel[1], pixel[2]] = matmul3(XYZ65_MAT_INV, [pixel[0], pixel[1], pixel[2]])
1341+
[pixel[0], pixel[1], pixel[2]] = mm(XYZ65_MAT_INV, [pixel[0], pixel[1], pixel[2]])
13381342
}
13391343

13401344
/// Convert from CIE LAB to CIE XYZ.
@@ -1369,9 +1373,9 @@ pub fn oklab_to_xyz<T: DType, const N: usize>(pixel: &mut [T; N])
13691373
where
13701374
Channels<N>: ValidChannels,
13711375
{
1372-
let mut lms = matmul3t([pixel[0], pixel[1], pixel[2]], OKLAB_M2_INV);
1376+
let mut lms = mm(OKLAB_M2_INV, [pixel[0], pixel[1], pixel[2]]);
13731377
lms.iter_mut().for_each(|c| *c = c.powi(3));
1374-
[pixel[0], pixel[1], pixel[2]] = matmul3t(lms, OKLAB_M1_INV);
1378+
[pixel[0], pixel[1], pixel[2]] = mm(OKLAB_M1_INV, lms);
13751379
}
13761380

13771381
/// Convert JzAzBz to CIE XYZ
@@ -1381,7 +1385,7 @@ pub fn jzazbz_to_xyz<T: DType, const N: usize>(pixel: &mut [T; N])
13811385
where
13821386
Channels<N>: ValidChannels,
13831387
{
1384-
let mut lms = matmul3(
1388+
let mut lms = mm(
13851389
JZAZBZ_M2_INV,
13861390
[
13871391
(pixel[0] + JZAZBZ_D0.to_dt())
@@ -1393,7 +1397,7 @@ where
13931397

13941398
lms.iter_mut().for_each(|c| *c = pqz_eotf(*c));
13951399

1396-
[pixel[0], pixel[1], pixel[2]] = matmul3(JZAZBZ_M1_INV, lms);
1400+
[pixel[0], pixel[1], pixel[2]] = mm(JZAZBZ_M1_INV, lms);
13971401

13981402
pixel[0] = pixel[2].fma((JZAZBZ_B - 1.0).to_dt(), pixel[0]) / JZAZBZ_B.to_dt();
13991403
pixel[1] = pixel[0].fma((JZAZBZ_G - 1.0).to_dt(), pixel[1]) / JZAZBZ_G.to_dt();
@@ -1415,10 +1419,10 @@ where
14151419
Channels<N>: ValidChannels,
14161420
{
14171421
// lms prime
1418-
let mut lms = matmul3(ICTCP_M2_INV, [pixel[0], pixel[1], pixel[2]]);
1422+
let mut lms = mm(ICTCP_M2_INV, [pixel[0], pixel[1], pixel[2]]);
14191423
// non-prime lms
14201424
lms.iter_mut().for_each(|c| *c = pq_eotf(*c));
1421-
[pixel[0], pixel[1], pixel[2]] = matmul3(ICTCP_M1_INV, lms);
1425+
[pixel[0], pixel[1], pixel[2]] = mm(ICTCP_M1_INV, lms);
14221426
}
14231427

14241428
/// Retrieves an LAB based space from its cylindrical representation.

0 commit comments

Comments
 (0)