1use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use std::fmt;
7
8use crate::{Full, GeneralMatrix, Invertible, MatMul, QuadraticForm, SquareMatrix};
9
10mod ops;
12
13pub use crate::diagonal::DiagonalMatrix;
14
/// Dense matrix of `f64` values with `N` rows and `M` columns.
///
/// Elements are stored row by row, so `rows[i][j]` is the element in row `i`
/// and column `j`.
#[serde_as]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Matrix<const N: usize, const M: usize> {
    /// The `N` rows of the matrix, each holding `M` values.
    // (De)serialized through `serde_with` as a nested array.
    #[serde_as(as = "[[_; M]; N]")]
    pub rows: [[f64; M]; N],
}
/// Square 2×2 matrix.
pub type Matrix22 = Matrix<2, 2>;
/// Square 3×3 matrix.
pub type Matrix33 = Matrix<3, 3>;
/// Square 4×4 matrix.
pub type Matrix44 = Matrix<4, 4>;
38
/// `cos(π/8)`: cosine half of the fixed fallback pair used by the Jacobi step
/// of the 3×3 SVD when the approximate rotation is rejected.
const C_STAR: f64 = 0.923_879_532_511_286_8;
/// `sin(π/8)`: sine half of the fallback pair that goes with `C_STAR`.
const S_STAR: f64 = 0.382_683_432_365_089_8;
43
44impl<const N: usize, const M: usize> GeneralMatrix for Matrix<N, M> {
45 #[inline]
55 fn zeros() -> Self {
56 Self {
57 rows: std::array::from_fn(|_| std::array::from_fn(|_| 0.0)),
58 }
59 }
60 #[inline]
61 fn shape(&self) -> (usize, usize) {
62 (self.n_rows(), self.n_columns())
63 }
64}
65
66impl<const N: usize, const M: usize> Full for Matrix<N, M> {
67 #[inline]
77 fn full(value: f64) -> Self {
78 Self {
79 rows: std::array::from_fn(|_| std::array::from_fn(|_| value)),
80 }
81 }
82}
83
84impl<const N: usize> SquareMatrix for Matrix<N, N> {
85 #[inline]
93 fn identity() -> Self {
94 Self {
95 rows: std::array::from_fn(|i| std::array::from_fn(|j| if i == j { 1.0 } else { 0.0 })),
96 }
97 }
98}
99
100impl<const N: usize, const M: usize, const K: usize> MatMul<Matrix<M, K>> for Matrix<N, M> {
101 type Output = Matrix<N, K>;
102 #[inline]
121 fn matmul(&self, rhs: &Matrix<M, K>) -> Self::Output {
122 let mut result = Self::Output::zeros();
123 for n in 0..N {
124 for k in 0..K {
125 for m in 0..M {
126 result.rows[n][k] += self.rows[n][m] * rhs.rows[m][k];
127 }
128 }
129 }
130
131 result
132 }
133}
134
135impl<const N: usize, const M: usize> MatMul<DiagonalMatrix<M>> for Matrix<N, M> {
136 type Output = Matrix<N, M>;
137
138 #[inline]
158 fn matmul(&self, rhs: &DiagonalMatrix<M>) -> Self::Output {
159 let mut result = Self::Output::zeros();
160 for (i, row) in result.rows.iter_mut().enumerate().take(N) {
161 for j in 0..M {
162 row[j] = self.rows[i][j] * rhs[j];
163 }
164 }
165 result
166 }
167}
168
169impl<const N: usize, const M: usize> Matrix<N, M> {
170 #[inline]
187 #[must_use]
188 pub fn transpose(&self) -> Matrix<M, N> {
189 Matrix {
190 rows: std::array::from_fn(|j| std::array::from_fn(|i| self[(i, j)])),
191 }
192 }
193
194 #[inline]
209 #[must_use]
210 pub fn map_rows<F>(self, f: F) -> Self
211 where
212 F: FnMut([f64; M]) -> [f64; M],
213 {
214 Self {
215 rows: self.rows.map(f),
216 }
217 }
218
219 #[inline]
234 #[must_use]
235 pub fn map_columns<F>(self, f: F) -> Self
236 where
237 F: FnMut([f64; N]) -> [f64; N],
238 {
239 self.clone().transpose().map_rows(f).transpose()
240 }
241 #[inline]
251 #[must_use]
252 pub fn map_elements<F>(self, f: F) -> Self
253 where
254 F: Fn(f64) -> f64,
255 {
256 Self {
257 rows: self.rows.map(|v| v.map(&f)),
258 }
259 }
260
261 #[inline]
281 pub fn iter_elements(&self) -> impl Iterator<Item = f64> {
282 self.rows.iter().flat_map(|row| row.iter().copied())
283 }
284
285 #[inline]
303 pub fn iter_elements_mut(&mut self) -> impl Iterator<Item = &mut f64> {
304 self.rows.iter_mut().flat_map(|row| row.iter_mut())
305 }
306
307 #[inline]
322 pub fn iter_rows(&self) -> impl Iterator<Item = [f64; M]> {
323 self.rows.iter().copied()
324 }
325
326 #[inline]
345 pub fn fold_elements<B, F>(self, init: B, mut f: F) -> B
346 where
347 F: FnMut(B, f64) -> B,
348 {
349 let mut accum = init;
350 for x in self.iter_elements() {
351 accum = f(accum, x);
352 }
353 accum
354 }
355
356 #[inline]
382 pub fn fold_rows<B, F>(self, init: B, mut f: F) -> B
383 where
384 F: FnMut(B, [f64; M]) -> B,
385 {
386 let mut accum = init;
387 for x in self.iter_rows() {
388 accum = f(accum, x);
389 }
390 accum
391 }
392
393 #[must_use]
405 #[inline]
406 pub const fn n_rows(&self) -> usize {
407 N
408 }
409 #[must_use]
421 #[inline]
422 pub const fn n_columns(&self) -> usize {
423 M
424 }
425}
426impl<const N: usize> Matrix<N, N> {
427 #[inline]
442 #[must_use]
443 pub fn with_diagonal(diagonal: [f64; N]) -> Self {
444 DiagonalMatrix { elements: diagonal }.to_dense()
445 }
446
447 #[must_use]
465 #[inline]
466 pub fn determinant(&self) -> f64 {
467 #[inline]
469 fn det2(a: f64, b: f64, c: f64, d: f64) -> f64 {
470 a * d - b * c
471 }
472 #[inline]
475 fn det_recursive_noslice<const N: usize>(
476 matrix: &Matrix<N, N>,
477 row: usize,
478 col_indices: [usize; N],
479 minor_size: usize,
480 ) -> f64 {
481 if minor_size == 4 {
483 let r = matrix.rows;
484 let c = col_indices;
485
486 let (i0, i1, i2, i3) = (row, row + 1, row + 2, row + 3);
488 let [j0, j1, j2, j3] = c[..4] else {
489 unreachable!() };
491
492 let m0 = det2(r[i2][j2], r[i2][j3], r[i3][j2], r[i3][j3]);
493 let m1 = det2(r[i2][j1], r[i2][j3], r[i3][j1], r[i3][j3]);
494 let m2 = det2(r[i2][j1], r[i2][j2], r[i3][j1], r[i3][j2]);
495 let m3 = det2(r[i2][j0], r[i2][j3], r[i3][j0], r[i3][j3]);
496 let m4 = det2(r[i2][j0], r[i2][j2], r[i3][j0], r[i3][j2]);
497 let m5 = det2(r[i2][j0], r[i2][j1], r[i3][j0], r[i3][j1]);
498
499 return r[i0][j0] * (r[i1][j1] * m0 - r[i1][j2] * m1 + r[i1][j3] * m2)
500 - r[i0][j1] * (r[i1][j0] * m0 - r[i1][j2] * m3 + r[i1][j3] * m4)
501 + r[i0][j2] * (r[i1][j0] * m1 - r[i1][j1] * m3 + r[i1][j3] * m5)
502 - r[i0][j3] * (r[i1][j0] * m2 - r[i1][j1] * m4 + r[i1][j2] * m5);
503 }
504
505 (0..minor_size).fold(0.0, |acc, idx| {
506 let minor_size = minor_size - 1;
507 let mut minor_cols = [0; N];
508 for j in 0..minor_size {
509 minor_cols[j] = col_indices[j + usize::from(j >= idx)];
511 }
512
513 let sign = if idx % 2 == 0 { 1.0 } else { -1.0 };
514 acc + sign
515 * matrix.rows[row][col_indices[idx]]
516 * det_recursive_noslice(matrix, row + 1, minor_cols, minor_size)
517 })
518 }
519 match N {
521 0 => return 0.0,
522 1 => return self[(0, 0)],
523 2 => return det2(self[(0, 0)], self[(1, 0)], self[(0, 1)], self[(1, 1)]),
524 3 => {
525 return self[(0, 0)] * det2(self[(1, 1)], self[(1, 2)], self[(2, 1)], self[(2, 2)])
526 - self[(0, 1)] * det2(self[(1, 0)], self[(1, 2)], self[(2, 0)], self[(2, 2)])
527 + self[(0, 2)] * det2(self[(1, 0)], self[(1, 1)], self[(2, 0)], self[(2, 1)]);
528 }
529 _ => (),
530 }
531
532 let col_indices = std::array::from_fn(|i| i);
533 det_recursive_noslice(self, 0, col_indices, N)
534 }
535
536 #[must_use]
550 #[inline]
551 pub fn trace(&self) -> f64 {
552 std::array::from_fn::<_, N, _>(|i| self[(i, i)])
553 .iter()
554 .sum()
555 }
556
557 #[must_use]
575 #[inline]
576 pub fn powi(&self, n: u32) -> Self {
577 (0..n).fold(Self::identity(), |acc, _| acc.matmul(self))
578 }
579
580 #[must_use]
596 #[inline]
597 pub fn diagonal(&self) -> DiagonalMatrix<N> {
598 DiagonalMatrix {
599 elements: std::array::from_fn(|i| self.rows[i][i]),
600 }
601 }
602}
603
604impl<const N: usize> QuadraticForm<N> for Matrix<N, N> {
605 #[inline]
606 fn compute_quadratic_form(&self, x: &[f64; N]) -> f64 {
607 let mut result = 0.0;
608
609 for i in 0..N {
610 for j in 0..N {
611 result += x[i] * self[(i, j)] * x[j];
612 }
613 }
614 result
615 }
616}
617
618impl Invertible for Matrix<2, 2> {
619 #[inline]
634 fn inverse(&self) -> Option<Self> {
635 let det = self.determinant();
636 if det == 0.0 {
637 None
638 } else {
639 let inv_det = det.recip();
640 Some(Self {
641 rows: [
642 [inv_det * self.rows[1][1], inv_det * -self.rows[0][1]],
643 [inv_det * -self.rows[1][0], inv_det * self.rows[0][0]],
644 ],
645 })
646 }
647 }
648}
649
650impl Invertible for Matrix<3, 3> {
651 #[inline]
666 fn inverse(&self) -> Option<Self> {
667 #[inline]
668 fn cross(u: [f64; 3], v: [f64; 3]) -> [f64; 3] {
669 [
670 u[1] * v[2] - u[2] * v[1],
671 u[2] * v[0] - u[0] * v[2],
672 u[0] * v[1] - u[1] * v[0],
673 ]
674 }
675 let [x0, x1, x2] = self.rows;
676 let det = self.determinant();
677 if det == 0.0 {
678 return None;
679 }
680 let rows = [cross(x1, x2), cross(x2, x0), cross(x0, x1)];
681 Some(det.recip() * Self { rows }.transpose())
682 }
683}
684impl Invertible for Matrix<4, 4> {
685 #[inline]
701 fn inverse(&self) -> Option<Self> {
702 let det = self.determinant();
703 if det == 0.0 {
704 return None;
705 }
706 let tr_a = self.trace();
708 let a_sq = self.powi(2);
709 let tr_a_sq = a_sq.trace();
710 let a_cb = a_sq.matmul(self);
711 let tr_a_cb = a_cb.trace();
712 let left =
713 (1.0 / 6.0) * (tr_a.powi(3) - 3.0 * tr_a * tr_a_sq + 2.0 * tr_a_cb) * Self::identity();
714 let center = (1.0 / 2.0) * *self * (tr_a.powi(2) - tr_a_sq);
715 Some(det.recip() * (left - center + a_sq * tr_a - a_cb))
716 }
717}
718
719impl<const N: usize, const M: usize> fmt::Display for Matrix<N, M> {
720 #[inline]
721 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
722 f.write_str(&format!(
723 "[{}]",
724 self.iter_rows()
725 .map(|row| {
726 format!(
727 "[{}]",
728 row.iter()
729 .map(ToString::to_string)
730 .collect::<Vec<_>>()
731 .join(", ")
732 )
733 })
734 .collect::<Vec<_>>()
735 .join(",\n ")
736 ))
737 }
738}
739impl Matrix<2, 2> {
740 #[must_use]
772 #[inline]
773 pub fn svd(&self) -> (Self, DiagonalMatrix<2>, Self) {
774 let a_plus_d = f64::midpoint(self[(0, 0)], self[(1, 1)]);
775
776 let a_minus_d = (self[(0, 0)] - self[(1, 1)]) / 2.0;
777 let b_plus_c = f64::midpoint(self[(0, 1)], self[(1, 0)]);
778 let b_minus_c = (self[(1, 0)] - self[(0, 1)]) / 2.0;
779 let (q, r) = (
780 (a_plus_d.powi(2) + b_minus_c.powi(2)).sqrt(),
781 (a_minus_d.powi(2) + b_plus_c.powi(2)).sqrt(),
782 );
783
784 let sy = q - r;
785 let sign_sy = sy.signum();
786
787 let (a1, a2) = (
788 f64::atan2(b_plus_c, a_minus_d),
789 f64::atan2(b_minus_c, a_plus_d),
790 );
791
792 let gamma = f64::midpoint(a1, a2);
793 let beta = (a2 - a1) / 2.0;
794
795 let (sr, cr) = beta.sin_cos();
796 let (sl, cl) = gamma.sin_cos();
797
798 let u = Matrix22 {
799 rows: [[cl, -sl], [sl, cl]],
800 };
801 let vt = Matrix22 {
802 rows: [[cr, -sr], [sr * sign_sy, cr * sign_sy]],
803 };
804
805 let singular_values = DiagonalMatrix::<2> {
806 elements: [q + r, sy.abs()],
807 };
808
809 (u, singular_values, vt)
810 }
811}
812
813impl Matrix<3, 3> {
814 #[must_use]
846 #[inline]
847 pub fn svd(&self) -> (Self, DiagonalMatrix<3>, Self) {
848 #[inline]
849 fn jacobi_rotation(p_idx: usize, q_idx: usize, s: &mut Matrix33, v: &mut Matrix33) {
850 let c_h = 2.0 * (s[(p_idx, p_idx)] - s[(q_idx, q_idx)]);
853 let s_h = s[(p_idx, q_idx)];
854 let gamma = 5.828_427_124_746_2; let b = gamma * s_h.powi(2) < c_h.powi(2);
857 let omega = (c_h.powi(2) + s_h.powi(2)).sqrt().recip();
858
859 let (c_h_res, s_h_res) = if b {
860 (omega * c_h, omega * s_h)
861 } else {
862 (C_STAR, S_STAR)
863 };
864
865 let c_h_sq = c_h_res.powi(2);
867 let s_h_sq = s_h_res.powi(2);
868 let scale = c_h_sq + s_h_sq;
869 let cos = (c_h_sq - s_h_sq) / scale;
870 let sin = (2.0 * c_h_res * s_h_res) / scale;
871
872 let mut q_mat = Matrix33::identity();
873 q_mat[(p_idx, p_idx)] = cos;
874 q_mat[(p_idx, q_idx)] = -sin;
875 q_mat[(q_idx, p_idx)] = sin;
876 q_mat[(q_idx, q_idx)] = cos;
877
878 *s = q_mat.transpose().matmul(s).matmul(&q_mat);
879 *v = v.matmul(&q_mat);
880 }
881
882 #[inline]
883 fn cond_neg_swap_rows(c: bool, rows: &mut [[f64; 3]], i: usize, j: usize) {
884 if c {
885 rows.swap(i, j);
886 rows[j].iter_mut().for_each(|x| *x *= -1.0);
887 }
888 }
889
890 #[inline]
891 fn qr_givens_rotation(p: usize, q: usize, r_mat: &mut Matrix33, u_mat: &mut Matrix33) {
892 let rho = (r_mat[(p, p)].powi(2) + r_mat[(q, p)].powi(2)).sqrt();
893 let c = if rho == 0.0 { 1.0 } else { r_mat[(p, p)] / rho };
894 let s_val = if rho == 0.0 { 0.0 } else { r_mat[(q, p)] / rho };
895
896 let mut q_t = Matrix33::identity();
897 q_t[(p, p)] = c;
898 q_t[(p, q)] = s_val;
899 q_t[(q, p)] = -s_val;
900 q_t[(q, q)] = c;
901
902 *r_mat = q_t.matmul(r_mat);
903 *u_mat = u_mat.matmul(&q_t.transpose());
904 }
905
906 const NUM_JACOBI_SWEEPS: usize = 6; let mut singular_values = self.transpose().matmul(self);
909 let mut v = Self::identity();
910
911 for _ in 0..NUM_JACOBI_SWEEPS {
912 jacobi_rotation(0, 1, &mut singular_values, &mut v);
913 jacobi_rotation(0, 2, &mut singular_values, &mut v);
914 jacobi_rotation(1, 2, &mut singular_values, &mut v);
915 }
916
917 let mut b = self.matmul(&v);
919 let mut b_cols = b.transpose().rows;
920 let mut v_cols = v.transpose().rows;
921 let mut rhos: [f64; 3] = std::array::from_fn(|i| b_cols[i].iter().map(|&x| x * x).sum());
922
923 if rhos[0] < rhos[1] {
924 rhos.swap(0, 1);
925 cond_neg_swap_rows(true, &mut b_cols, 0, 1);
926 cond_neg_swap_rows(true, &mut v_cols, 0, 1);
927 }
928 if rhos[1] < rhos[2] {
929 rhos.swap(1, 2);
930 cond_neg_swap_rows(true, &mut b_cols, 1, 2);
931 cond_neg_swap_rows(true, &mut v_cols, 1, 2);
932 }
933 if rhos[0] < rhos[1] {
934 rhos.swap(0, 1);
935 cond_neg_swap_rows(true, &mut b_cols, 0, 1);
936 cond_neg_swap_rows(true, &mut v_cols, 0, 1);
937 }
938
939 b = Matrix { rows: b_cols }.transpose();
940 v = Matrix { rows: v_cols }.transpose();
941
942 let mut r = b;
944 let mut u = Self::identity();
945
946 qr_givens_rotation(0, 1, &mut r, &mut u);
947 qr_givens_rotation(0, 2, &mut r, &mut u);
948 qr_givens_rotation(1, 2, &mut r, &mut u);
949
950 let mut sigma = r.diagonal();
952
953 if u.determinant() < 0.0 {
954 u.rows.iter_mut().for_each(|row| row[2] *= -1.0);
955 sigma[2] *= -1.0;
956 }
957
958 if v.determinant() < 0.0 {
959 v.rows.iter_mut().for_each(|row| row[2] *= -1.0);
960 sigma[2] *= -1.0;
961 }
962
963 (u, sigma, v.transpose())
964 }
965}
// Implements `Copy` for `Matrix<$N, $M>` for one row count `$N` and every
// listed column count `$M`.
macro_rules! impl_copy_for_m {
    ($N:literal, $($M:literal),+) => { $(impl Copy for Matrix<$N, $M> {})+ };
}
// For every listed row count `$N`, implements `Copy` for `Matrix<$N, M>` with
// `M` in 1..=4 by delegating to `impl_copy_for_m!`.
macro_rules! impl_copy_for_n_m {
    ($($N:literal),+) => { $(impl_copy_for_m!($N, 1, 2, 3, 4);)+ };
}

// `Copy` is provided only for matrices with 1 to 4 rows and 1 to 4 columns.
// NOTE(review): presumably to keep larger matrices from being copied
// implicitly — confirm.
impl_copy_for_n_m!(1, 2, 3, 4);
976
#[cfg(test)]
mod tests {
    use std::{fmt::Debug, ops::Index};

    use super::*;
    use crate::matrix::{Matrix, Matrix22, Matrix33, Matrix44};
    use approxim::{assert_relative_eq, assert_ulps_eq, ulps_eq};

    use faer::Mat;
    use rstest::rstest;

    // Tolerance passed as `epsilon` to the approximate-equality macros below.
    const EPS: f64 = 1e-13;

    // Copies a row-major nested array into a `faer` matrix of the same shape.
    fn fill_faer<const N: usize, const M: usize>(m: [[f64; M]; N]) -> Mat<f64> {
        let mut faer_matrix = Mat::<f64>::zeros(N, M);
        for (i, row) in m.iter().enumerate() {
            for (j, el) in row.iter().enumerate() {
                *faer_matrix.get_mut(i, j) = *el;
            }
        }
        faer_matrix
    }
    // Copies an array into an `N`×1 `faer` column matrix.
    fn fill_faer_column<const N: usize>(c: [f64; N]) -> Mat<f64> {
        let mut faer_matrix = Mat::<f64>::zeros(N, 1);
        for (i, el) in c.iter().enumerate() {
            *faer_matrix.get_mut(i, 0) = *el;
        }
        faer_matrix
    }
    // Asserts element-wise approximate equality of the top-left `N`×`M` part of
    // two values indexable by `(row, column)`.
    fn assert_matrices_ulps_eq<
        const N: usize,
        const M: usize,
        T0: Index<(usize, usize), Output = f64> + Debug,
        T1: Index<(usize, usize), Output = f64> + Debug,
    >(
        m0: &T0,
        m1: &T1,
    ) {
        for i in 0..N {
            for j in 0..M {
                // The assertion only runs (and then fails) when the check is false.
                if !ulps_eq!(m0[(i, j)], m1[(i, j)], epsilon = EPS) {
                    assert_ulps_eq!(m0[(i, j)], m1[(i, j)], epsilon = EPS);
                }
            }
        }
    }
    // Asserts approximate equality of the first `N` entries of a diagonal
    // matrix and any value indexable by `usize`.
    fn assert_diags_ulps_eq<const N: usize>(
        m0: &DiagonalMatrix<N>,
        m1: &impl std::ops::Index<usize, Output = f64>,
    ) {
        for i in 0..N {
            assert_ulps_eq!(m0[i], m1[i], epsilon = EPS);
        }
    }
    // Determinant agrees with faer for sizes 1 through 8.
    #[rstest(
        rows,
        case([[-9.0]]),
        case([[1.0, -2.0], [3.0, 4.0]]),
        case([[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]]),
        case([[2.0, 0.0, 1.0], [3.0, 9.0, 9.0], [5.0, 1.0, 1.0]]),
        case(Matrix::<4, 4>::identity().rows),
        case([
            [-10.0, 4.0, 3.0, 4.0],
            [300.0, 5.0, 6.0, 7.0],
            [3.0, 6.0, 8.0, 9.0],
            [4.0, 7.0, 9.0, 10.0]
        ]),
        case(Matrix::<5, 5>::full(3.6).diagonal().to_dense().rows),
        case(Matrix::<8, 8>::identity().rows),
    )]
    fn test_determinant<const N: usize>(rows: [[f64; N]; N]) {
        let matrix = Matrix { rows };
        let faer_matrix = fill_faer(rows);

        let custom_det = matrix.determinant();
        let faer_det = faer_matrix.determinant();

        assert_relative_eq!(custom_det, faer_det, max_relative = 1e-14);
    }
    // Square matrix products agree with faer.
    #[rstest(
        a_rows, b_rows,
        case([[-9.0]], [[-9.0]]),
        case(
            [[1.0, -2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 0.0]]
        ),
        case(
            [[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]],
            [[-2.0, 1.0, 0.0], [3.0, 0.0, 1.0], [1.0, 4.0, -1.0]]
        ),
        case(
            [[2.0, 0.0, 1.0], [3.0, 0.0, 0.0], [5.0, 1.0, 1.0]],
            [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [4.0, 0.0, 0.0]]
        ),
        case(Matrix::<4, 4>::identity().rows, Matrix::<4, 4>::full(2.0).rows),
        case(Matrix::<5, 5>::full(3.6).diagonal().to_dense().rows, Matrix::<5, 5>::identity().rows),
        case(Matrix::<8, 8>::identity().rows, Matrix::<8, 8>::full(1.5).rows),
    )]
    fn test_matrix_multiply_square<const N: usize>(a_rows: [[f64; N]; N], b_rows: [[f64; N]; N]) {
        let a = Matrix { rows: a_rows };
        let b = Matrix { rows: b_rows };

        let faer_a = fill_faer(a_rows);
        let faer_b = fill_faer(b_rows);

        let custom_prod = a.matmul(&b);
        let faer_prod = faer_a * faer_b;
        assert_matrices_ulps_eq::<N, N, _, _>(&custom_prod, &faer_prod);
    }

    // Rectangular (`N`×`M` times `M`×`K`) products agree with faer.
    #[rstest]
    #[case(
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    )]
    #[case(
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]],
    )]
    #[case(
        [[1.0, 2.0]],
        [[3.0], [4.0]],
    )]
    #[case(
        [[2.0, 0.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    )]
    fn test_rectangular_matrix_multiply<const M: usize, const K: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] b_rows: [[f64; K]; M],
    ) {
        let a = Matrix { rows: a_rows };
        let b = Matrix { rows: b_rows };

        let faer_a = fill_faer(a_rows);
        let faer_b = fill_faer(b_rows);

        let custom_prod = a.matmul(&b);
        let faer_prod = faer_a * faer_b;
        assert_matrices_ulps_eq::<N, K, _, _>(&custom_prod, &faer_prod);
    }

    // 2×2 SVD: the factors reconstruct the input and match faer's factors
    // after aligning signs.
    #[rstest(
        rows,
        case::identity(Matrix22::identity().rows),
        case::mixed_sign([[1.0, -2.0], [3.0, 4.0]]),
        case::det_zero([[12.0, 2.0], [4.0, 0.0]]),
        case::large_range([[1000.0, 0.0], [0.0, 1e-4]]),
        case::jordan_block([[1.0, 1.0], [0.0, 1.0]]),
        case::full_ones(Matrix22::full(1.0).rows),
        case::shear([[1.0, 2.0], [0.0, 1.0]]),
        case::nilpotent([[0.0, 1.0], [0.0, 0.0]]),
        case::scaling([[2.0, 0.0], [0.0, 3.0]]),
    )]
    fn test_svd_2x2_faer(rows: [[f64; 2]; 2]) {
        let matrix = Matrix22 { rows };
        let (u, s, vt) = matrix.svd();

        // The factors must multiply back to the input.
        assert_matrices_ulps_eq::<2, 2, _, _>(&u.matmul(&s.to_dense()).matmul(&vt), &matrix);

        let faer = fill_faer(rows);
        let faersvd = faer.svd().unwrap();
        let (mut faeru, faers, mut faerv) =
            (faersvd.U().to_owned(), faersvd.S(), faersvd.V().to_owned());

        // Flip the second column of faer's factor when its determinant sign
        // differs from ours, so both use the same orientation.
        if faeru.determinant().signum() != u.determinant().signum() {
            faeru[(0, 1)] *= -1.0;
            faeru[(1, 1)] *= -1.0;
        }
        if faerv.determinant().signum() != vt.determinant().signum() {
            faerv[(0, 1)] *= -1.0;
            faerv[(1, 1)] *= -1.0;
        }

        assert_matrices_ulps_eq::<2, 2, _, _>(&u, &faeru);
        assert_diags_ulps_eq(&s, &faers);
        assert_matrices_ulps_eq::<2, 2, _, _>(&vt, &faerv.transpose());
    }

    // 2×2 SVD: the factors reconstruct the input and match nalgebra's factors.
    #[rstest(
        rows,
        case::identity(Matrix22::identity().rows),
        case::mixed_sign([[1.0, -2.0], [3.0, 4.0]]),
        case::det_zero([[12.0, 2.0], [4.0, 0.0]]),
        case::large_range([[1000.0, 0.0], [0.0, 1e-4]]),
        case::jordan_block([[1.0, 1.0], [0.0, 1.0]]),
        case::full_ones(Matrix22::full(1.0).rows),
        case::shear([[1.0, 2.0], [0.0, 1.0]]),
        case::nilpotent([[0.0, 1.0], [0.0, 0.0]]),
        case::scaling([[2.0, 0.0], [0.0, 3.0]]),
        case::reflect([[0.0, -1.0], [1.0, 0.0]]),
        case::negative_identity((Matrix22::identity()*-1.0).rows),
        case::anti_diagonal([[0.0, 1.0], [1.0, 0.0]]),
        case::singular([[1.0, 2.0], [2.0, 4.0]]),
    )]
    fn test_svd_2x2_nalgebra(rows: [[f64; 2]; 2]) {
        let matrix = Matrix22 { rows };
        let (u, s, vt) = matrix.svd();

        // The factors must multiply back to the input.
        assert_matrices_ulps_eq::<2, 2, _, _>(&u.matmul(&s.to_dense()).matmul(&vt), &matrix);

        // NOTE(review): transposed after conversion, presumably because
        // nalgebra reads the nested array column by column — confirm.
        let na = nalgebra::Matrix2::from(rows).transpose();
        let nasvd = na.svd(true, true);
        let (nau, nas, navt) = (nasvd.u.unwrap(), nasvd.singular_values, nasvd.v_t.unwrap());

        assert_matrices_ulps_eq::<2, 2, _, _>(&u, &nau);
        assert_diags_ulps_eq::<2>(&s, &nas);
        assert_matrices_ulps_eq::<2, 2, _, _>(&vt, &navt);
    }

    // 3×3 SVD: reconstruction, determinant +1 and orthogonality of the
    // factors, and singular values (by absolute value) against faer.
    #[rstest(
        rows,
        case::identity(Matrix33::identity().rows),
        case::general([[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]]),
        case::symmetric([[1.0, 2.0, 3.0], [2.0, 5.0, 6.0], [3.0, 6.0, 9.0]]),
        case::near_singular([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.000_000_1]]),
        case::scaled((Matrix33::identity() * 10.0).rows),
        case::from_paper_repo([[0.433_807, 0.269_185, 0.543_034], [0.440_339, 0.443_024, 0.166_492], [0.793_913, 0.125_443, 0.730_333]]),
    )]
    fn test_svd_3x3_faer(rows: [[f64; 3]; 3]) {
        let matrix = Matrix33 { rows };
        let (u, s, vt) = matrix.svd();

        // The factors must multiply back to the input.
        let m_recon = u.matmul(&s).matmul(&vt);
        assert_matrices_ulps_eq::<3, 3, _, _>(&m_recon, &matrix);

        // Both orthogonal factors must have determinant +1 and satisfy Q·Qᵀ = I.
        assert_relative_eq!(u.determinant(), 1.0, epsilon = EPS);
        assert_relative_eq!(vt.transpose().determinant(), 1.0, epsilon = EPS);
        assert_matrices_ulps_eq::<3, 3, _, _>(&u.matmul(&u.transpose()), &Matrix33::identity());
        assert_matrices_ulps_eq::<3, 3, _, _>(&vt.matmul(&vt.transpose()), &Matrix33::identity());

        let faer_mat = fill_faer(rows);
        let faersvd = faer_mat.svd().unwrap();

        let faers = faersvd.S();
        // Our values may be negative, so compare absolute values with faer's.
        assert_diags_ulps_eq(
            &DiagonalMatrix {
                elements: s.elements.map(f64::abs),
            },
            &faers,
        );
    }

    // Transpose of a square 2×2 matrix agrees with faer.
    #[test]
    fn test_transpose_2x2() {
        let rows = [[1.0, -2.0], [3.0, 4.0]];
        let matrix = Matrix::<2, 2> { rows };
        let faer_matrix = fill_faer(rows);
        let custom_transpose = matrix.transpose();
        let faer_transpose = faer_matrix.transpose();
        assert_matrices_ulps_eq::<2, 2, _, _>(&custom_transpose, &faer_transpose);
    }

    // Transpose of a 2×3 matrix is 3×2 and agrees with faer.
    #[test]
    fn test_transpose_2x3() {
        let rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let matrix = Matrix::<2, 3> { rows };
        let faer_matrix = fill_faer(rows);
        let custom_transpose = matrix.transpose();
        let faer_transpose = faer_matrix.transpose();
        assert_matrices_ulps_eq::<3, 2, _, _>(&custom_transpose, &faer_transpose);
    }

    // Transpose of a 3×2 matrix is 2×3 and agrees with faer.
    #[test]
    fn test_transpose_3x2() {
        let rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let matrix = Matrix::<3, 2> { rows };
        let faer_matrix = fill_faer(rows);
        let custom_transpose = matrix.transpose();
        let faer_transpose = faer_matrix.transpose();
        assert_matrices_ulps_eq::<2, 3, _, _>(&custom_transpose, &faer_transpose);
    }

    // A 1×1 matrix equals its own transpose.
    #[test]
    fn test_transpose_1x1() {
        let rows = [[-9.0]];
        let matrix = Matrix::<1, 1> { rows };
        assert_matrices_ulps_eq::<1, 1, _, _>(&matrix.transpose(), &matrix);
    }

    // `zeros` and `full` fill every element with the expected value.
    #[test]
    fn test_general_matrix_methods() {
        let zeros = Matrix::<2, 3>::zeros();
        let full = Matrix::<2, 3>::full(7.5);
        for i in 0..2 {
            for j in 0..3 {
                assert_eq!(zeros[(i, j)], 0.0);
                assert_eq!(full[(i, j)], 7.5);
            }
        }
    }

    // `identity` yields ones on the diagonal and zeros elsewhere.
    #[test]
    fn test_square_matrix_methods() {
        let identity = Matrix::<3, 3>::identity();
        let expected = Matrix::<3, 3> {
            rows: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        };
        assert_matrices_ulps_eq::<3, 3, _, _>(&identity, &expected);
    }

    // `diagonal` extracts the main diagonal and `to_dense` rebuilds a matrix
    // with zeros off the diagonal.
    #[test]
    fn test_diag_conversions() {
        let mat = Matrix::<3, 3> {
            rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
        };
        let diag = mat.diagonal();
        let expected_diag = DiagonalMatrix {
            elements: [1.0, 5.0, 9.0],
        };
        assert_diags_ulps_eq(&diag, &expected_diag);

        let from_diag = diag.to_dense();
        let expected_from_diag = Matrix {
            rows: [[1.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 9.0]],
        };
        assert_matrices_ulps_eq::<3, 3, _, _>(&from_diag, &expected_from_diag);
    }

    // The quadratic form agrees with `xᵀ · A · x` computed through faer.
    #[rstest(
        rows, vars,
        case(
            [[1.0, 2.0], [3.0, 4.0]],
            [0.5, 1.5]
        ),
        case(
            [[2.0, 0.0, 1.0], [3.0, 0.0, 0.0], [5.0, 1.0, 1.0]],
            [1.0, 2.0, 3.0]
        ),
        case(
            [[-33.0, 2.0, 0.0, 1.0], [3.0, -45.0, 0.0, 0.0], [5.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0]],
            [1.0, 2.0, 3.0, 4.0]
        ),
    )]
    fn test_quadratic_form<const N: usize>(rows: [[f64; N]; N], vars: [f64; N]) {
        let matrix = Matrix { rows };
        let result = matrix.compute_quadratic_form(&vars);
        assert_relative_eq!(
            result,
            (fill_faer_column(vars).transpose() * fill_faer(rows) * fill_faer_column(vars))[(0, 0)],
            max_relative = 1e-14
        );
    }

    // A 2×2 matrix times its inverse gives the identity.
    #[rstest(
        rows,
        case([[1.0, -2.0], [3.0, 4.0]]),
        case([[10.0, 0.0], [0.0, 0.1]]),
        case([[1.0, 1.0], [0.0, 1.0]]),
    )]
    fn test_inverse_2x2(rows: [[f64; 2]; 2]) {
        let matrix = Matrix22 { rows };
        let inv_matrix = matrix.inverse().expect("invertible");
        let product = matrix.matmul(&inv_matrix);
        let identity = Matrix22::identity();

        assert_matrices_ulps_eq::<2, 2, _, _>(&product, &identity);
    }
    // A 3×3 matrix times its inverse gives the identity.
    #[rstest(
        rows,
        case(Matrix33::identity().rows),
        case([[1.0, -3.0, 4.5], [3.0, 4.0,5.0], [8.0, -9.3, 10.0]]),
        case([[2.0, 1.0, 0.0], [0.0, 1.0, 2.0], [1.0, 0.0, 1.0]]),
        case([[5.0, -2.0, 3.0], [1.0, 0.0, 4.0], [-1.0, 2.0, 1.0]])
    )]
    fn test_inverse_3x3(rows: [[f64; 3]; 3]) {
        let matrix = Matrix33 { rows };
        let inv_matrix = matrix.inverse().expect("invertible");
        let product = matrix.matmul(&inv_matrix);
        let identity = Matrix33::identity();

        assert_matrices_ulps_eq::<3, 3, _, _>(&product, &identity);
    }
    // A 4×4 matrix times its inverse gives the identity.
    #[rstest(
        rows,
        case(Matrix44::identity().rows),
        case([[1.0, -4.0, 4.5,1.0], [4.0, 4.0,5.0,0.0], [8.0, -9.4, 10.0,9.0], [-1.0,-1.0,1.0,1.0]]),
    )]
    fn test_inverse_4x4(rows: [[f64; 4]; 4]) {
        let matrix = Matrix44 { rows };
        let inv_matrix = matrix.inverse().expect("invertible");
        let product = matrix.matmul(&inv_matrix);
        let identity = Matrix44::identity();

        assert_matrices_ulps_eq::<4, 4, _, _>(&product, &identity);
    }
}