hoomd_linear_algebra/
matrix.rs

1// Copyright (c) 2024-2026 The Regents of the University of Michigan.
2// Part of hoomd-rs, released under the BSD 3-Clause License.
3
4use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use std::fmt;
7
8use crate::{Full, GeneralMatrix, Invertible, MatMul, QuadraticForm, SquareMatrix};
9
10/// ``std::ops`` implementations for [`Matrix`]
11mod ops;
12
13pub use crate::diagonal::DiagonalMatrix;
14
15/// A dense matrix of `f64` with `N` rows and `M` columns, stored inline as `[[f64; M]; N]`.
16///
17/// # Example
18/// ```
19/// use hoomd_linear_algebra::matrix::Matrix;
20///
21/// let a = Matrix {
22///     rows: [[-1.0, 4.0, -6.0], [2.0, -3.0, 1.0]],
23/// };
24/// ```
25#[serde_as]
26#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
27pub struct Matrix<const N: usize, const M: usize> {
28    /// The elements of the matrix, one inner array per row: `rows[i][j]` is row `i`, column `j`.
29    #[serde_as(as = "[[_; M]; N]")]
30    pub rows: [[f64; M]; N],
31}
32/// A square [`Matrix`] with 2 rows and 2 columns.
33pub type Matrix22 = Matrix<2, 2>;
34/// A square [`Matrix`] with 3 rows and 3 columns.
35pub type Matrix33 = Matrix<3, 3>;
36/// A square [`Matrix`] with 4 rows and 4 columns.
37pub type Matrix44 = Matrix<4, 4>;
38
39/// cos(π/8): fallback cosine term used by the Jacobi rotation of the 3x3 `svd` when its approximate estimate is rejected.
40const C_STAR: f64 = 0.923_879_532_511_286_8;
41/// sin(π/8): fallback sine term paired with [`C_STAR`] in the Jacobi rotation of the 3x3 `svd`.
42const S_STAR: f64 = 0.382_683_432_365_089_8;
43
44impl<const N: usize, const M: usize> GeneralMatrix for Matrix<N, M> {
45    /// Fill a matrix with zeros.
46    ///
47    /// # Examples
48    /// ```
49    /// use hoomd_linear_algebra::{GeneralMatrix, matrix::Matrix22};
50    ///
51    /// let m = Matrix22::zeros();
52    /// assert_eq!(m.rows, [[0.0, 0.0], [0.0, 0.0]]);
53    /// ```
54    #[inline]
55    fn zeros() -> Self {
56        Self {
57            rows: std::array::from_fn(|_| std::array::from_fn(|_| 0.0)),
58        }
59    }
60    #[inline]
61    fn shape(&self) -> (usize, usize) {
62        (self.n_rows(), self.n_columns())
63    }
64}
65
66impl<const N: usize, const M: usize> Full for Matrix<N, M> {
67    /// Construct a matrix with the same value in every element.
68    ///
69    /// # Examples
70    /// ```
71    /// use hoomd_linear_algebra::{Full, matrix::Matrix22};
72    ///
73    /// let m = Matrix22::full(5.0);
74    /// assert_eq!(m.rows, [[5.0, 5.0], [5.0, 5.0]]);
75    /// ```
76    #[inline]
77    fn full(value: f64) -> Self {
78        Self {
79            rows: std::array::from_fn(|_| std::array::from_fn(|_| value)),
80        }
81    }
82}
83
84impl<const N: usize> SquareMatrix for Matrix<N, N> {
85    /// # Examples
86    /// ```
87    /// use hoomd_linear_algebra::{SquareMatrix, matrix::Matrix22};
88    ///
89    /// let m = Matrix22::identity();
90    /// assert_eq!(m.rows, [[1.0, 0.0], [0.0, 1.0]]);
91    /// ```
92    #[inline]
93    fn identity() -> Self {
94        Self {
95            rows: std::array::from_fn(|i| std::array::from_fn(|j| if i == j { 1.0 } else { 0.0 })),
96        }
97    }
98}
99
100impl<const N: usize, const M: usize, const K: usize> MatMul<Matrix<M, K>> for Matrix<N, M> {
101    type Output = Matrix<N, K>;
102    /// Matrix-matrix multiplication.
103    ///
104    /// # Examples
105    /// ```
106    /// use hoomd_linear_algebra::{
107    ///     MatMul,
108    ///     matrix::{Matrix, Matrix22},
109    /// };
110    ///
111    /// let a = Matrix22 {
112    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
113    /// };
114    /// let b = Matrix22 {
115    ///     rows: [[5.0, 6.0], [7.0, 8.0]],
116    /// };
117    /// let c = a.matmul(&b);
118    /// assert_eq!(c.rows, [[19.0, 22.0], [43.0, 50.0]]);
119    /// ```
120    #[inline]
121    fn matmul(&self, rhs: &Matrix<M, K>) -> Self::Output {
122        let mut result = Self::Output::zeros();
123        for n in 0..N {
124            for k in 0..K {
125                for m in 0..M {
126                    result.rows[n][k] += self.rows[n][m] * rhs.rows[m][k];
127                }
128            }
129        }
130
131        result
132    }
133}
134
135impl<const N: usize, const M: usize> MatMul<DiagonalMatrix<M>> for Matrix<N, M> {
136    type Output = Matrix<N, M>;
137
138    /// Matrix-diagonal matrix multiplication.
139    ///
140    /// This is equivalent to scaling each column of a [`Matrix`] by the corresponding
141    /// element in a [`DiagonalMatrix`].
142    ///
143    /// # Examples
144    /// ```
145    /// use hoomd_linear_algebra::{
146    ///     Full, GeneralMatrix, MatMul,
147    ///     matrix::{DiagonalMatrix, Matrix22},
148    /// };
149    ///
150    /// let diag = DiagonalMatrix {
151    ///     elements: [3.0, 4.0],
152    /// };
153    /// let a = Matrix22::full(1.0).matmul(&diag);
154    /// assert_eq!(a[(0, 1)], 4.0);
155    /// assert_eq!(a[(1, 0)], 3.0);
156    /// ```
157    #[inline]
158    fn matmul(&self, rhs: &DiagonalMatrix<M>) -> Self::Output {
159        let mut result = Self::Output::zeros();
160        for (i, row) in result.rows.iter_mut().enumerate().take(N) {
161            for j in 0..M {
162                row[j] = self.rows[i][j] * rhs[j];
163            }
164        }
165        result
166    }
167}
168
169impl<const N: usize, const M: usize> Matrix<N, M> {
170    /// Interchange the rows and columns of matrix.
171    ///
172    /// ```math
173    /// \mathbf{A}^\top_{ji} = \mathbf{A}_{ij}
174    /// ```
175    ///
176    /// # Examples
177    /// ```
178    /// use hoomd_linear_algebra::matrix::Matrix;
179    ///
180    /// let m = Matrix {
181    ///     rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
182    /// };
183    /// let m_t = m.transpose();
184    /// assert_eq!(m_t.rows, [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
185    /// ```
186    #[inline]
187    #[must_use]
188    pub fn transpose(&self) -> Matrix<M, N> {
189        Matrix {
190            rows: std::array::from_fn(|j| std::array::from_fn(|i| self[(i, j)])),
191        }
192    }
193
194    /// Apply a function to an [`Matrix`] by rows.
195    ///
196    /// # Examples
197    /// ```
198    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
199    ///
200    /// let m = Matrix22::full(3.0);
201    /// assert_eq!(
202    ///     m.map_rows(|v| [v[0] + 2.0, v[1]]),
203    ///     Matrix22 {
204    ///         rows: [[5.0, 3.0], [5.0, 3.0]]
205    ///     }
206    /// );
207    /// ```
208    #[inline]
209    #[must_use]
210    pub fn map_rows<F>(self, f: F) -> Self
211    where
212        F: FnMut([f64; M]) -> [f64; M],
213    {
214        Self {
215            rows: self.rows.map(f),
216        }
217    }
218
219    /// Apply a function to an [`Matrix`] by columns.
220    ///
221    /// # Examples
222    /// ```
223    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
224    ///
225    /// let m = Matrix22::full(3.0);
226    /// assert_eq!(
227    ///     m.map_columns(|v| [v[0] + 2.0, v[1]]),
228    ///     Matrix22 {
229    ///         rows: [[5.0, 5.0], [3.0, 3.0]]
230    ///     }
231    /// );
232    /// ```
233    #[inline]
234    #[must_use]
235    pub fn map_columns<F>(self, f: F) -> Self
236    where
237        F: FnMut([f64; N]) -> [f64; N],
238    {
239        self.clone().transpose().map_rows(f).transpose()
240    }
241    /// Apply a function to [`Matrix`] elements.
242    ///
243    /// # Examples
244    /// ```
245    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix33};
246    ///
247    /// let m = Matrix33::full(3.0);
248    /// assert_eq!(m.map_elements(|x| x + 2.0), m + Matrix33::full(2.0));
249    /// ```
250    #[inline]
251    #[must_use]
252    pub fn map_elements<F>(self, f: F) -> Self
253    where
254        F: Fn(f64) -> f64,
255    {
256        Self {
257            rows: self.rows.map(|v| v.map(&f)),
258        }
259    }
260
261    /// Iterate over every element in the [`Matrix`].
262    ///
263    /// The iterator yields all matrix elements in row major order.
264    ///
265    /// # Examples
266    /// ```
267    /// use hoomd_linear_algebra::{SquareMatrix, matrix::Matrix22};
268    ///
269    /// let x = Matrix22 {
270    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
271    /// };
272    ///
273    /// let mut iterator = x.iter_elements();
274    /// assert_eq!(iterator.next(), Some(1.0));
275    /// assert_eq!(iterator.next(), Some(2.0));
276    /// assert_eq!(iterator.next(), Some(3.0));
277    /// assert_eq!(iterator.next(), Some(4.0));
278    /// assert_eq!(iterator.next(), None);
279    /// ```
280    #[inline]
281    pub fn iter_elements(&self) -> impl Iterator<Item = f64> {
282        self.rows.iter().flat_map(|row| row.iter().copied())
283    }
284
285    /// Iterate over every element in the [`Matrix`].
286    ///
287    /// The iterator yields all matrix elements in row major order.
288    ///
289    /// # Examples
290    /// ```
291    /// use hoomd_linear_algebra::{SquareMatrix, matrix::Matrix22};
292    ///
293    /// let mut x = Matrix22 {
294    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
295    /// };
296    ///
297    /// let x_copy = x.clone();
298    /// let mut iterator = x.iter_elements_mut();
299    /// iterator.for_each(|x| *x *= 2.0);
300    /// assert_eq!(x, x_copy * 2.0);
301    /// ```
302    #[inline]
303    pub fn iter_elements_mut(&mut self) -> impl Iterator<Item = &mut f64> {
304        self.rows.iter_mut().flat_map(|row| row.iter_mut())
305    }
306
307    /// Iterate over matrix rows.
308    ///
309    /// # Examples
310    /// ```
311    /// use hoomd_linear_algebra::{SquareMatrix, matrix::Matrix22};
312    ///
313    /// let x = Matrix22 {
314    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
315    /// };
316    /// let mut iterator = x.iter_rows();
317    /// assert_eq!(iterator.next(), Some([1.0, 2.0]));
318    /// assert_eq!(iterator.next(), Some([3.0, 4.0]));
319    /// assert_eq!(iterator.next(), None);
320    /// ```
321    #[inline]
322    pub fn iter_rows(&self) -> impl Iterator<Item = [f64; M]> {
323        self.rows.iter().copied()
324    }
325
326    /// Folds every element into an accumulator by applying an operation.
327    ///
328    /// `fold_elements` takes two arguments: an initial value, and a closure
329    /// with two arguments: an ‘accumulator’, and an element. The closure
330    /// returns the value that the accumulator should have for the next iteration.
331    ///
332    /// The initial value is the value the accumulator will have on the first call.
333    /// After applying this closure to every element of the flattened iterator,
334    /// `fold_elements` returns the accumulator.
335    ///
336    /// # Examples
337    /// ```
338    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
339    ///
340    /// let m = Matrix22::full(3.0);
341    /// // Sum the elements of a matrix
342    /// assert_eq!(m.fold_elements(0.0, |acc, x| acc + x), 3.0 * 4.0);
343    /// ```
344    #[inline]
345    pub fn fold_elements<B, F>(self, init: B, mut f: F) -> B
346    where
347        F: FnMut(B, f64) -> B,
348    {
349        let mut accum = init;
350        for x in self.iter_elements() {
351            accum = f(accum, x);
352        }
353        accum
354    }
355
356    /// Folds every row into an accumulator by applying an operation.
357    ///
358    /// `fold_rows` takes two arguments: an initial value, and a closure with two
359    /// arguments: an ‘accumulator’, and an element. The closure returns the
360    /// value that the accumulator should have for the next iteration.
361    ///
362    /// The initial value is the value the accumulator will have on the first
363    /// call. After applying this closure to every element of the flattened
364    /// iterator, `fold_rows` returns the accumulator.
365    ///
366    /// # Examples
367    /// ```
368    /// use hoomd_linear_algebra::{GeneralMatrix, matrix::Matrix22};
369    ///
370    /// let m = Matrix22 {
371    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
372    /// };
373    /// // Average the columns of a matrix
374    /// let n_rows = m.n_rows() as f64;
375    /// assert_eq!(
376    ///     m.fold_rows([0.0; 2], |acc, x| [acc[0] + x[0], acc[1] + x[1]])
377    ///         .map(|x| x / n_rows),
378    ///     [(1.0 + 3.0) / 2.0, (2.0 + 4.0) / 2.0]
379    /// );
380    /// ```
381    #[inline]
382    pub fn fold_rows<B, F>(self, init: B, mut f: F) -> B
383    where
384        F: FnMut(B, [f64; M]) -> B,
385    {
386        let mut accum = init;
387        for x in self.iter_rows() {
388            accum = f(accum, x);
389        }
390        accum
391    }
392
393    /// Get the number of rows in the [`Matrix`], i.e. the const generic parameter `N`.
394    ///
395    /// # Examples
396    /// ```
397    /// use hoomd_linear_algebra::matrix::Matrix;
398    ///
399    /// let m = Matrix {
400    ///     rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
401    /// };
402    /// assert_eq!(m.n_rows(), 2);
403    /// ```
404    #[must_use]
405    #[inline]
406    pub const fn n_rows(&self) -> usize {
407        N
408    }
409    /// Get the number of columns in the [`Matrix`], i.e. the const generic parameter `M`.
410    ///
411    /// # Examples
412    /// ```
413    /// use hoomd_linear_algebra::matrix::Matrix;
414    ///
415    /// let m = Matrix {
416    ///     rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
417    /// };
418    /// assert_eq!(m.n_columns(), 3);
419    /// ```
420    #[must_use]
421    #[inline]
422    pub const fn n_columns(&self) -> usize {
423        M
424    }
425}
426impl<const N: usize> Matrix<N, N> {
427    /// Construct a square matrix with the given diagonal.
428    ///
429    /// `diagonal[i]` is the matrix element $` A_{ii} `$.
430    /// All off-diagonal elements are 0.
431    ///
432    /// # Example
433    ///
434    /// ```
435    /// use hoomd_linear_algebra::matrix::Matrix;
436    ///
437    /// let a = Matrix::with_diagonal([2.0, -3.0]);
438    ///
439    /// assert_eq!(a.rows, [[2.0, 0.0], [0.0, -3.0]]);
440    /// ```
441    #[inline]
442    #[must_use]
443    pub fn with_diagonal(diagonal: [f64; N]) -> Self {
444        DiagonalMatrix { elements: diagonal }.to_dense()
445    }
446
447    /// Compute the signed hypervolume of the hyperparallelepiped defined by a matrix.
448    ///
449    /// This implementation uses the Laplace expansion, which is optimal for small
450    /// matrices but will be extremely slow for large matrices due to its $`O(N!)`$
451    /// complexity.
452    ///
453    /// # Example
454    ///
455    /// ```
456    /// use hoomd_linear_algebra::{SquareMatrix, matrix::Matrix22};
457    ///
458    /// let identity = Matrix22::identity();
459    /// assert_eq!(identity.determinant(), 1.0);
460    ///
461    /// let scaled = identity * 2.0;
462    /// assert_eq!(scaled.determinant(), 2.0 * 2.0);
463    /// ```
464    #[must_use]
465    #[inline]
466    pub fn determinant(&self) -> f64 {
467        // Compute the determinant of a 2x2 minor.
468        #[inline]
469        fn det2(a: f64, b: f64, c: f64, d: f64) -> f64 {
470            a * d - b * c
471        }
472        // Because math with const generics is not allowed in rust, we compute the indices
473        // of each submatrix and recur on those noncontiguous segments of the input.
474        #[inline]
475        fn det_recursive_noslice<const N: usize>(
476            matrix: &Matrix<N, N>,
477            row: usize,
478            col_indices: [usize; N],
479            minor_size: usize,
480        ) -> f64 {
481            // If we recur any lower than 4x4 minors, performance drops dramatically
482            if minor_size == 4 {
483                let r = matrix.rows;
484                let c = col_indices;
485
486                // Map recursive indices to direct matrix indices
487                let (i0, i1, i2, i3) = (row, row + 1, row + 2, row + 3);
488                let [j0, j1, j2, j3] = c[..4] else {
489                    unreachable!() // N >= 4 if we reach this point
490                };
491
492                let m0 = det2(r[i2][j2], r[i2][j3], r[i3][j2], r[i3][j3]);
493                let m1 = det2(r[i2][j1], r[i2][j3], r[i3][j1], r[i3][j3]);
494                let m2 = det2(r[i2][j1], r[i2][j2], r[i3][j1], r[i3][j2]);
495                let m3 = det2(r[i2][j0], r[i2][j3], r[i3][j0], r[i3][j3]);
496                let m4 = det2(r[i2][j0], r[i2][j2], r[i3][j0], r[i3][j2]);
497                let m5 = det2(r[i2][j0], r[i2][j1], r[i3][j0], r[i3][j1]);
498
499                return r[i0][j0] * (r[i1][j1] * m0 - r[i1][j2] * m1 + r[i1][j3] * m2)
500                    - r[i0][j1] * (r[i1][j0] * m0 - r[i1][j2] * m3 + r[i1][j3] * m4)
501                    + r[i0][j2] * (r[i1][j0] * m1 - r[i1][j1] * m3 + r[i1][j3] * m5)
502                    - r[i0][j3] * (r[i1][j0] * m2 - r[i1][j1] * m4 + r[i1][j2] * m5);
503            }
504
505            (0..minor_size).fold(0.0, |acc, idx| {
506                let minor_size = minor_size - 1;
507                let mut minor_cols = [0; N];
508                for j in 0..minor_size {
509                    // Store the indices for the next recursion, skipping col idx
510                    minor_cols[j] = col_indices[j + usize::from(j >= idx)];
511                }
512
513                let sign = if idx % 2 == 0 { 1.0 } else { -1.0 };
514                acc + sign
515                    * matrix.rows[row][col_indices[idx]]
516                    * det_recursive_noslice(matrix, row + 1, minor_cols, minor_size)
517            })
518        }
519        // Early exit for small matrices to ensure we get the optimal code.
520        match N {
521            0 => return 0.0,
522            1 => return self[(0, 0)],
523            2 => return det2(self[(0, 0)], self[(1, 0)], self[(0, 1)], self[(1, 1)]),
524            3 => {
525                return self[(0, 0)] * det2(self[(1, 1)], self[(1, 2)], self[(2, 1)], self[(2, 2)])
526                    - self[(0, 1)] * det2(self[(1, 0)], self[(1, 2)], self[(2, 0)], self[(2, 2)])
527                    + self[(0, 2)] * det2(self[(1, 0)], self[(1, 1)], self[(2, 0)], self[(2, 1)]);
528            }
529            _ => (),
530        }
531
532        let col_indices = std::array::from_fn(|i| i);
533        det_recursive_noslice(self, 0, col_indices, N)
534    }
535
536    /// Compute the sum of diagonal elements of a square matrix.
537    ///
538    /// # Examples
539    ///
540    /// ```
541    /// use hoomd_linear_algebra::{SquareMatrix, matrix::Matrix22};
542    ///
543    /// let identity = Matrix22::identity();
544    /// assert_eq!(identity.trace(), 2.0);
545    ///
546    /// let scaled = identity * 3.0;
547    /// assert_eq!(scaled.trace(), 3.0 + 3.0);
548    /// ```
549    #[must_use]
550    #[inline]
551    pub fn trace(&self) -> f64 {
552        std::array::from_fn::<_, N, _>(|i| self[(i, i)])
553            .iter()
554            .sum()
555    }
556
557    /// Compute a matrix to an integer power
558    ///
559    /// ```math
560    /// \mathbf{A}^n = \prod_{i=1}^n \mathbf{A}
561    /// ```
562    ///
563    /// # Examples
564    ///
565    /// ```
566    /// use hoomd_linear_algebra::{Full, GeneralMatrix, MatMul, matrix::Matrix22};
567    ///
568    /// let matrix = Matrix22::full(2.0);
569    ///
570    /// assert_eq!(matrix.powi(2), matrix.matmul(&matrix));
571    ///
572    /// assert_eq!(matrix.powi(2).powi(2), matrix.powi(4));
573    /// ```
574    #[must_use]
575    #[inline]
576    pub fn powi(&self, n: u32) -> Self {
577        (0..n).fold(Self::identity(), |acc, _| acc.matmul(self))
578    }
579
580    /// Extract the diagonal elements from a square matrix.
581    ///
582    /// This method returns a `DiagonalMatrix<N>` containing the diagonal elements
583    /// of the input matrix, where the element at position `(i, i)` is taken from
584    /// the input matrix. All off-diagonal elements are ignored.
585    ///
586    /// # Examples
587    /// ```
588    /// use hoomd_linear_algebra::matrix::Matrix33;
589    /// let a = Matrix33 {
590    ///     rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
591    /// };
592    /// let b = a.diagonal();
593    /// assert_eq!(b.elements, [1.0, 5.0, 9.0]);
594    /// ```
595    #[must_use]
596    #[inline]
597    pub fn diagonal(&self) -> DiagonalMatrix<N> {
598        DiagonalMatrix {
599            elements: std::array::from_fn(|i| self.rows[i][i]),
600        }
601    }
602}
603
604impl<const N: usize> QuadraticForm<N> for Matrix<N, N> {
605    #[inline]
606    fn compute_quadratic_form(&self, x: &[f64; N]) -> f64 {
607        let mut result = 0.0;
608
609        for i in 0..N {
610            for j in 0..N {
611                result += x[i] * self[(i, j)] * x[j];
612            }
613        }
614        result
615    }
616}
617
618impl Invertible for Matrix<2, 2> {
619    /// Compute the inverse of a matrix. Will be `None` if the matrix is not invertible.
620    ///
621    /// This implementation uses a closed form solution for the matrix inverse.
622    ///
623    /// # Examples
624    /// ```
625    /// use hoomd_linear_algebra::{Invertible, matrix::Matrix22};
626    ///
627    /// let m = Matrix22 {
628    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
629    /// };
630    /// let m_inv = m.inverse().unwrap();
631    /// assert_eq!(m_inv.rows, [[-2.0, 1.0], [1.5, -0.5]]);
632    /// ```
633    #[inline]
634    fn inverse(&self) -> Option<Self> {
635        let det = self.determinant();
636        if det == 0.0 {
637            None
638        } else {
639            let inv_det = det.recip();
640            Some(Self {
641                rows: [
642                    [inv_det * self.rows[1][1], inv_det * -self.rows[0][1]],
643                    [inv_det * -self.rows[1][0], inv_det * self.rows[0][0]],
644                ],
645            })
646        }
647    }
648}
649
650impl Invertible for Matrix<3, 3> {
651    /// Compute the inverse of a matrix. Will be `None` if the matrix is not invertible.
652    ///
653    /// This implementation uses a closed form solution for the matrix inverse based on
654    /// the cross product of rows.
655    ///
656    /// # Example
657    /// ```
658    /// use hoomd_linear_algebra::{Invertible, SquareMatrix, matrix::Matrix};
659    ///
660    /// let m = Matrix::identity() * 5.0;
661    /// let m_inv = m.inverse();
662    ///
663    /// assert_eq!(m_inv, Some(Matrix::with_diagonal([1.0 / 5.0; 3])));
664    /// ```
665    #[inline]
666    fn inverse(&self) -> Option<Self> {
667        #[inline]
668        fn cross(u: [f64; 3], v: [f64; 3]) -> [f64; 3] {
669            [
670                u[1] * v[2] - u[2] * v[1],
671                u[2] * v[0] - u[0] * v[2],
672                u[0] * v[1] - u[1] * v[0],
673            ]
674        }
675        let [x0, x1, x2] = self.rows;
676        let det = self.determinant();
677        if det == 0.0 {
678            return None;
679        }
680        let rows = [cross(x1, x2), cross(x2, x0), cross(x0, x1)];
681        Some(det.recip() * Self { rows }.transpose())
682    }
683}
684impl Invertible for Matrix<4, 4> {
685    /// Compute the inverse of a matrix. Will be `None` if the matrix is not invertible.
686    ///
687    /// This implementation uses a closed form solution for the matrix inverse based on
688    /// the Cayley–Hamilton method.
689    ///
690    /// # Examples
691    /// ```
692    /// use hoomd_linear_algebra::{
693    ///     Invertible, MatMul, SquareMatrix, matrix::Matrix44,
694    /// };
695    ///
696    /// let m = Matrix44::identity();
697    /// let m_inv = m.inverse().unwrap();
698    /// assert_eq!(m_inv.rows, m.rows);
699    /// ```
700    #[inline]
701    fn inverse(&self) -> Option<Self> {
702        let det = self.determinant();
703        if det == 0.0 {
704            return None;
705        }
706        // Compute components of Cayley–Hamilton factorization
707        let tr_a = self.trace();
708        let a_sq = self.powi(2);
709        let tr_a_sq = a_sq.trace();
710        let a_cb = a_sq.matmul(self);
711        let tr_a_cb = a_cb.trace();
712        let left =
713            (1.0 / 6.0) * (tr_a.powi(3) - 3.0 * tr_a * tr_a_sq + 2.0 * tr_a_cb) * Self::identity();
714        let center = (1.0 / 2.0) * *self * (tr_a.powi(2) - tr_a_sq);
715        Some(det.recip() * (left - center + a_sq * tr_a - a_cb))
716    }
717}
718
719impl<const N: usize, const M: usize> fmt::Display for Matrix<N, M> {
720    #[inline]
721    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
722        f.write_str(&format!(
723            "[{}]",
724            self.iter_rows()
725                .map(|row| {
726                    format!(
727                        "[{}]",
728                        row.iter()
729                            .map(ToString::to_string)
730                            .collect::<Vec<_>>()
731                            .join(", ")
732                    )
733                })
734                .collect::<Vec<_>>()
735                .join(",\n ")
736        ))
737    }
738}
739impl Matrix<2, 2> {
740    /// Decompose a [`Matrix22`] into a rotation, a scaling, and a second rotation.
741    ///
742    /// ```math
743    /// \mathbf{A} = \mathbf{U} \boldsymbol{\Sigma} \mathbf{V}^\intercal
744    /// ```
745    /// This implementation is based on the math in [Blinn 1996], and
746    /// ensures good (but not optimal) numerical stability. For certain
747    /// pathological inputs, preconditioning the matrix could provide a benefit
748    /// in numerical stability.
749    ///
750    /// `svd` sets all singular values to be positive.
751    ///
752    /// [Blinn 1996]: https://doi.org/10.1109/38.486688
753    ///
754    /// # Examples
755    /// ```
756    /// use hoomd_linear_algebra::{
757    ///     MatMul,
758    ///     matrix::{DiagonalMatrix, Matrix22},
759    /// };
760    /// let m = Matrix22 {
761    ///     rows: [[1.0, 2.0], [3.0, 4.0]],
762    /// };
763    /// let (u, s, vt) = m.svd();
764    /// let m_recon = u.matmul(&s.to_dense()).matmul(&vt);
765    /// for i in 0..2 {
766    ///     for j in 0..2 {
767    ///         assert!((m.rows[i][j] - m_recon.rows[i][j]).abs() < 1e-9);
768    ///     }
769    /// }
770    /// ```
771    #[must_use]
772    #[inline]
773    pub fn svd(&self) -> (Self, DiagonalMatrix<2>, Self) {
774        let a_plus_d = f64::midpoint(self[(0, 0)], self[(1, 1)]);
775
776        let a_minus_d = (self[(0, 0)] - self[(1, 1)]) / 2.0;
777        let b_plus_c = f64::midpoint(self[(0, 1)], self[(1, 0)]);
778        let b_minus_c = (self[(1, 0)] - self[(0, 1)]) / 2.0;
779        let (q, r) = (
780            (a_plus_d.powi(2) + b_minus_c.powi(2)).sqrt(),
781            (a_minus_d.powi(2) + b_plus_c.powi(2)).sqrt(),
782        );
783
784        let sy = q - r;
785        let sign_sy = sy.signum();
786
787        let (a1, a2) = (
788            f64::atan2(b_plus_c, a_minus_d),
789            f64::atan2(b_minus_c, a_plus_d),
790        );
791
792        let gamma = f64::midpoint(a1, a2);
793        let beta = (a2 - a1) / 2.0;
794
795        let (sr, cr) = beta.sin_cos();
796        let (sl, cl) = gamma.sin_cos();
797
798        let u = Matrix22 {
799            rows: [[cl, -sl], [sl, cl]],
800        };
801        let vt = Matrix22 {
802            rows: [[cr, -sr], [sr * sign_sy, cr * sign_sy]],
803        };
804
805        let singular_values = DiagonalMatrix::<2> {
806            elements: [q + r, sy.abs()],
807        };
808
809        (u, singular_values, vt)
810    }
811}
812
813impl Matrix<3, 3> {
814    /// Decompose a [`Matrix33`] into a rotation, a scaling, and a second rotation.
815    ///
816    /// ```math
817    /// \mathbf{A} = \mathbf{U} \boldsymbol{\Sigma} \mathbf{V}^\top
818    /// ```
819    /// This implementation is based on the method described by [McAdams 2011], which
820    ///  is a fast variant of the Jacobi iteration method.
821    ///
822    /// The method ensures that U and V are pure rotation matrices (determinant = 1).
823    /// As a result, the third singular value may be negative. For a conventional
824    /// SVD with non-negative singular values, the sign can be absorbed into U.
825    ///
826    /// [McAdams 2011]: https://digital.library.wisc.edu/1793/60736
827    ///
828    /// # Examples
829    /// ```
830    /// use hoomd_linear_algebra::{
831    ///     MatMul, SquareMatrix,
832    ///     matrix::{DiagonalMatrix, Matrix33},
833    /// };
834    /// let m = Matrix33 {
835    ///     rows: [[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]],
836    /// };
837    /// let (u, s, vt) = m.svd();
838    /// let m_recon = u.matmul(&s.to_dense()).matmul(&vt);
839    /// for i in 0..3 {
840    ///     for j in 0..3 {
841    ///         assert!((m.rows[i][j] - m_recon.rows[i][j]).abs() < 1e-9);
842    ///     }
843    /// }
844    /// ```
845    #[must_use]
846    #[inline]
847    pub fn svd(&self) -> (Self, DiagonalMatrix<3>, Self) {
848        #[inline]
849        fn jacobi_rotation(p_idx: usize, q_idx: usize, s: &mut Matrix33, v: &mut Matrix33) {
850            // ApproxGivensQuaternion adapted to build rotation matrix
851
852            let c_h = 2.0 * (s[(p_idx, p_idx)] - s[(q_idx, q_idx)]);
853            let s_h = s[(p_idx, q_idx)];
854            let gamma = 5.828_427_124_746_2; // 3 + 2 * sqrt(2)
855
856            let b = gamma * s_h.powi(2) < c_h.powi(2);
857            let omega = (c_h.powi(2) + s_h.powi(2)).sqrt().recip();
858
859            let (c_h_res, s_h_res) = if b {
860                (omega * c_h, omega * s_h)
861            } else {
862                (C_STAR, S_STAR)
863            };
864
865            // Build normalized rotation matrix from quaternion components
866            let c_h_sq = c_h_res.powi(2);
867            let s_h_sq = s_h_res.powi(2);
868            let scale = c_h_sq + s_h_sq;
869            let cos = (c_h_sq - s_h_sq) / scale;
870            let sin = (2.0 * c_h_res * s_h_res) / scale;
871
872            let mut q_mat = Matrix33::identity();
873            q_mat[(p_idx, p_idx)] = cos;
874            q_mat[(p_idx, q_idx)] = -sin;
875            q_mat[(q_idx, p_idx)] = sin;
876            q_mat[(q_idx, q_idx)] = cos;
877
878            *s = q_mat.transpose().matmul(s).matmul(&q_mat);
879            *v = v.matmul(&q_mat);
880        }
881
882        #[inline]
883        fn cond_neg_swap_rows(c: bool, rows: &mut [[f64; 3]], i: usize, j: usize) {
884            if c {
885                rows.swap(i, j);
886                rows[j].iter_mut().for_each(|x| *x *= -1.0);
887            }
888        }
889
890        #[inline]
891        fn qr_givens_rotation(p: usize, q: usize, r_mat: &mut Matrix33, u_mat: &mut Matrix33) {
892            let rho = (r_mat[(p, p)].powi(2) + r_mat[(q, p)].powi(2)).sqrt();
893            let c = if rho == 0.0 { 1.0 } else { r_mat[(p, p)] / rho };
894            let s_val = if rho == 0.0 { 0.0 } else { r_mat[(q, p)] / rho };
895
896            let mut q_t = Matrix33::identity();
897            q_t[(p, p)] = c;
898            q_t[(p, q)] = s_val;
899            q_t[(q, p)] = -s_val;
900            q_t[(q, q)] = c;
901
902            *r_mat = q_t.matmul(r_mat);
903            *u_mat = u_mat.matmul(&q_t.transpose());
904        }
905
906        // Symmetric Eigenanalysis of A^\top * A using Jacobi iterations
907        const NUM_JACOBI_SWEEPS: usize = 6; // Paper suggests 4, we want more accuracy
908        let mut singular_values = self.transpose().matmul(self);
909        let mut v = Self::identity();
910
911        for _ in 0..NUM_JACOBI_SWEEPS {
912            jacobi_rotation(0, 1, &mut singular_values, &mut v);
913            jacobi_rotation(0, 2, &mut singular_values, &mut v);
914            jacobi_rotation(1, 2, &mut singular_values, &mut v);
915        }
916
917        // Sort singular values and vectors
918        let mut b = self.matmul(&v);
919        let mut b_cols = b.transpose().rows;
920        let mut v_cols = v.transpose().rows;
921        let mut rhos: [f64; 3] = std::array::from_fn(|i| b_cols[i].iter().map(|&x| x * x).sum());
922
923        if rhos[0] < rhos[1] {
924            rhos.swap(0, 1);
925            cond_neg_swap_rows(true, &mut b_cols, 0, 1);
926            cond_neg_swap_rows(true, &mut v_cols, 0, 1);
927        }
928        if rhos[1] < rhos[2] {
929            rhos.swap(1, 2);
930            cond_neg_swap_rows(true, &mut b_cols, 1, 2);
931            cond_neg_swap_rows(true, &mut v_cols, 1, 2);
932        }
933        if rhos[0] < rhos[1] {
934            rhos.swap(0, 1);
935            cond_neg_swap_rows(true, &mut b_cols, 0, 1);
936            cond_neg_swap_rows(true, &mut v_cols, 0, 1);
937        }
938
939        b = Matrix { rows: b_cols }.transpose();
940        v = Matrix { rows: v_cols }.transpose();
941
942        // QR Decomposition of B = U * Sigma
943        let mut r = b;
944        let mut u = Self::identity();
945
946        qr_givens_rotation(0, 1, &mut r, &mut u);
947        qr_givens_rotation(0, 2, &mut r, &mut u);
948        qr_givens_rotation(1, 2, &mut r, &mut u);
949
950        // Enforce conventions for outputs. Rotations are proper and S can be negative
951        let mut sigma = r.diagonal();
952
953        if u.determinant() < 0.0 {
954            u.rows.iter_mut().for_each(|row| row[2] *= -1.0);
955            sigma[2] *= -1.0;
956        }
957
958        if v.determinant() < 0.0 {
959            v.rows.iter_mut().for_each(|row| row[2] *= -1.0);
960            sigma[2] *= -1.0;
961        }
962
963        (u, sigma, v.transpose())
964    }
965}
966/// Macro to generate impls for a given row size `N` and multiple column sizes `M`.
967macro_rules! impl_copy_for_m {
968    ($N:literal, $($M:literal),+) => { $(impl Copy for Matrix<$N, $M> {})+ };
969}
970/// Implement Copy for matrices of an input size `N`, `M`
971macro_rules! impl_copy_for_n_m {
972    ($($N:literal),+) => { $(impl_copy_for_m!($N, 1, 2, 3, 4);)+ };
973}
974
975impl_copy_for_n_m!(1, 2, 3, 4);
976
977#[cfg(test)]
978mod tests {
979    use std::{fmt::Debug, ops::Index};
980
981    use super::*;
982    use crate::matrix::{Matrix, Matrix22, Matrix33, Matrix44};
983    use approxim::{assert_relative_eq, assert_ulps_eq, ulps_eq};
984
985    use faer::Mat;
986    use rstest::rstest;
987
988    const EPS: f64 = 1e-13;
989
990    fn fill_faer<const N: usize, const M: usize>(m: [[f64; M]; N]) -> Mat<f64> {
991        let mut faer_matrix = Mat::<f64>::zeros(N, M);
992        for (i, row) in m.iter().enumerate() {
993            for (j, el) in row.iter().enumerate() {
994                *faer_matrix.get_mut(i, j) = *el;
995            }
996        }
997        faer_matrix
998    }
999    fn fill_faer_column<const N: usize>(c: [f64; N]) -> Mat<f64> {
1000        let mut faer_matrix = Mat::<f64>::zeros(N, 1);
1001        for (i, el) in c.iter().enumerate() {
1002            *faer_matrix.get_mut(i, 0) = *el;
1003        }
1004        faer_matrix
1005    }
1006    fn assert_matrices_ulps_eq<
1007        const N: usize,
1008        const M: usize,
1009        T0: Index<(usize, usize), Output = f64> + Debug,
1010        T1: Index<(usize, usize), Output = f64> + Debug,
1011    >(
1012        m0: &T0,
1013        m1: &T1,
1014    ) {
1015        for i in 0..N {
1016            for j in 0..M {
1017                if !ulps_eq!(m0[(i, j)], m1[(i, j)], epsilon = EPS) {
1018                    assert_ulps_eq!(m0[(i, j)], m1[(i, j)], epsilon = EPS);
1019                }
1020            }
1021        }
1022    }
1023    fn assert_diags_ulps_eq<const N: usize>(
1024        m0: &DiagonalMatrix<N>,
1025        m1: &impl std::ops::Index<usize, Output = f64>,
1026    ) {
1027        for i in 0..N {
1028            assert_ulps_eq!(m0[i], m1[i], epsilon = EPS);
1029        }
1030    }
1031    #[rstest(
1032        rows,
1033        case([[-9.0]]),
1034        case([[1.0, -2.0], [3.0, 4.0]]),
1035        case([[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]]),
1036        case([[2.0, 0.0, 1.0], [3.0, 9.0, 9.0], [5.0, 1.0, 1.0]]),
1037        case(Matrix::<4, 4>::identity().rows),
1038        case([
1039            [-10.0, 4.0, 3.0, 4.0],
1040            [300.0, 5.0, 6.0, 7.0],
1041            [3.0, 6.0, 8.0, 9.0],
1042            [4.0, 7.0, 9.0, 10.0]
1043        ]),
1044        case(Matrix::<5, 5>::full(3.6).diagonal().to_dense().rows),
1045        case(Matrix::<8, 8>::identity().rows),
1046    )]
1047    fn test_determinant<const N: usize>(rows: [[f64; N]; N]) {
1048        let matrix = Matrix { rows };
1049        let faer_matrix = fill_faer(rows);
1050
1051        let custom_det = matrix.determinant();
1052        let faer_det = faer_matrix.determinant();
1053
1054        assert_relative_eq!(custom_det, faer_det, max_relative = 1e-14);
1055    }
1056    #[rstest(
1057        a_rows, b_rows,
1058        case([[-9.0]], [[-9.0]]),
1059        case(
1060            [[1.0, -2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 0.0]]
1061        ),
1062        case(
1063            [[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]],
1064            [[-2.0, 1.0, 0.0], [3.0, 0.0, 1.0], [1.0, 4.0, -1.0]]
1065        ),
1066        case(
1067            [[2.0, 0.0, 1.0], [3.0, 0.0, 0.0], [5.0, 1.0, 1.0]],
1068            [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [4.0, 0.0, 0.0]]
1069        ),
1070        case(Matrix::<4, 4>::identity().rows, Matrix::<4, 4>::full(2.0).rows),
1071        case(Matrix::<5, 5>::full(3.6).diagonal().to_dense().rows, Matrix::<5, 5>::identity().rows),
1072        case(Matrix::<8, 8>::identity().rows, Matrix::<8, 8>::full(1.5).rows),
1073    )]
1074    fn test_matrix_multiply_square<const N: usize>(a_rows: [[f64; N]; N], b_rows: [[f64; N]; N]) {
1075        let a = Matrix { rows: a_rows };
1076        let b = Matrix { rows: b_rows };
1077
1078        let faer_a = fill_faer(a_rows);
1079        let faer_b = fill_faer(b_rows);
1080
1081        let custom_prod = a.matmul(&b);
1082        let faer_prod = faer_a * faer_b;
1083        assert_matrices_ulps_eq::<N, N, _, _>(&custom_prod, &faer_prod);
1084    }
1085
1086    #[rstest]
1087    #[case(
1088        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
1089        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
1090    )]
1091    #[case(
1092        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
1093        [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]],
1094    )]
1095    #[case(
1096        [[1.0, 2.0]],
1097        [[3.0], [4.0]],
1098    )]
1099    #[case(
1100        [[2.0, 0.0, 1.0]],
1101        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
1102    )]
1103    fn test_rectangular_matrix_multiply<const M: usize, const K: usize, const N: usize>(
1104        #[case] a_rows: [[f64; M]; N],
1105        #[case] b_rows: [[f64; K]; M],
1106    ) {
1107        let a = Matrix { rows: a_rows };
1108        let b = Matrix { rows: b_rows };
1109
1110        let faer_a = fill_faer(a_rows);
1111        let faer_b = fill_faer(b_rows);
1112
1113        let custom_prod = a.matmul(&b);
1114        let faer_prod = faer_a * faer_b;
1115        assert_matrices_ulps_eq::<N, K, _, _>(&custom_prod, &faer_prod);
1116    }
1117
1118    #[rstest(
1119        rows,
1120        case::identity(Matrix22::identity().rows),
1121        case::mixed_sign([[1.0, -2.0], [3.0, 4.0]]),
1122        case::det_zero([[12.0, 2.0], [4.0, 0.0]]),
1123        case::large_range([[1000.0, 0.0], [0.0, 1e-4]]),
1124        case::jordan_block([[1.0, 1.0], [0.0, 1.0]]),
1125        case::full_ones(Matrix22::full(1.0).rows),
1126        case::shear([[1.0, 2.0], [0.0, 1.0]]),
1127        case::nilpotent([[0.0, 1.0], [0.0, 0.0]]),
1128        case::scaling([[2.0, 0.0], [0.0, 3.0]]),
1129        // The fast algorithm does not compute the correct result for these degenerate
1130        // cases. test_svd_2x2_nalgebra verifies we reproduce the result expected for
1131        // this algorithm.
1132        // case::reflect([[0.0, -1.0], [1.0, 0.0]]),
1133        // case::negative_identity((Matrix22::identity()*-1.0).rows),
1134        // case::anti_diagonal([[0.0, 1.0], [1.0, 0.0]]),
1135        // case::singular([[1.0, 2.0], [2.0, 4.0]]),
1136    )]
1137    fn test_svd_2x2_faer(rows: [[f64; 2]; 2]) {
1138        let matrix = Matrix22 { rows };
1139        let (u, s, vt) = matrix.svd();
1140
1141        // Verify we can rebuild A from UΣVt
1142        assert_matrices_ulps_eq::<2, 2, _, _>(&u.matmul(&s.to_dense()).matmul(&vt), &matrix);
1143
1144        // Test against faer
1145        let faer = fill_faer(rows);
1146        let faersvd = faer.svd().unwrap();
1147        let (mut faeru, faers, mut faerv) =
1148            (faersvd.U().to_owned(), faersvd.S(), faersvd.V().to_owned());
1149
1150        if faeru.determinant().signum() != u.determinant().signum() {
1151            faeru[(0, 1)] *= -1.0;
1152            faeru[(1, 1)] *= -1.0;
1153        }
1154        if faerv.determinant().signum() != vt.determinant().signum() {
1155            faerv[(0, 1)] *= -1.0;
1156            faerv[(1, 1)] *= -1.0;
1157        }
1158
1159        assert_matrices_ulps_eq::<2, 2, _, _>(&u, &faeru);
1160        assert_diags_ulps_eq(&s, &faers);
1161        // Note that faer returns V, not Vt
1162        assert_matrices_ulps_eq::<2, 2, _, _>(&vt, &faerv.transpose());
1163    }
1164
1165    #[rstest(
1166        rows,
1167        case::identity(Matrix22::identity().rows),
1168        case::mixed_sign([[1.0, -2.0], [3.0, 4.0]]),
1169        case::det_zero([[12.0, 2.0], [4.0, 0.0]]),
1170        case::large_range([[1000.0, 0.0], [0.0, 1e-4]]),
1171        case::jordan_block([[1.0, 1.0], [0.0, 1.0]]),
1172        case::full_ones(Matrix22::full(1.0).rows),
1173        case::shear([[1.0, 2.0], [0.0, 1.0]]),
1174        case::nilpotent([[0.0, 1.0], [0.0, 0.0]]),
1175        case::scaling([[2.0, 0.0], [0.0, 3.0]]),
1176        case::reflect([[0.0, -1.0], [1.0, 0.0]]), // Numerical stability
1177        case::negative_identity((Matrix22::identity()*-1.0).rows),
1178        case::anti_diagonal([[0.0, 1.0], [1.0, 0.0]]),
1179        case::singular([[1.0, 2.0], [2.0, 4.0]]),
1180    )]
1181    fn test_svd_2x2_nalgebra(rows: [[f64; 2]; 2]) {
1182        let matrix = Matrix22 { rows };
1183        let (u, s, vt) = matrix.svd();
1184
1185        // Verify we can rebuild A from UΣVt
1186        assert_matrices_ulps_eq::<2, 2, _, _>(&u.matmul(&s.to_dense()).matmul(&vt), &matrix);
1187
1188        // Test against nalgebra
1189        let na = nalgebra::Matrix2::from(rows).transpose();
1190        let nasvd = na.svd(true, true);
1191        let (nau, nas, navt) = (nasvd.u.unwrap(), nasvd.singular_values, nasvd.v_t.unwrap());
1192
1193        assert_matrices_ulps_eq::<2, 2, _, _>(&u, &nau);
1194        assert_diags_ulps_eq::<2>(&s, &nas);
1195        assert_matrices_ulps_eq::<2, 2, _, _>(&vt, &navt);
1196    }
1197
1198    #[rstest(
1199        rows,
1200        case::identity(Matrix33::identity().rows),
1201        case::general([[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]]),
1202        case::symmetric([[1.0, 2.0, 3.0], [2.0, 5.0, 6.0], [3.0, 6.0, 9.0]]),
1203        case::near_singular([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.000_000_1]]),
1204        case::scaled((Matrix33::identity() * 10.0).rows),
1205        case::from_paper_repo([[0.433_807, 0.269_185, 0.543_034], [0.440_339, 0.443_024, 0.166_492], [0.793_913, 0.125_443, 0.730_333]]), // See https://github.com/ericjang/svd3
1206    )]
1207    fn test_svd_3x3_faer(rows: [[f64; 3]; 3]) {
1208        let matrix = Matrix33 { rows };
1209        let (u, s, vt) = matrix.svd();
1210
1211        // Verify reconstruction
1212        let m_recon = u.matmul(&s).matmul(&vt);
1213        assert_matrices_ulps_eq::<3, 3, _, _>(&m_recon, &matrix);
1214
1215        // Verify properties of U and V
1216        assert_relative_eq!(u.determinant(), 1.0, epsilon = EPS);
1217        assert_relative_eq!(vt.transpose().determinant(), 1.0, epsilon = EPS);
1218        assert_matrices_ulps_eq::<3, 3, _, _>(&u.matmul(&u.transpose()), &Matrix33::identity());
1219        assert_matrices_ulps_eq::<3, 3, _, _>(&vt.matmul(&vt.transpose()), &Matrix33::identity());
1220
1221        // Compare with faer SVD
1222        let faer_mat = fill_faer(rows);
1223        let faersvd = faer_mat.svd().unwrap();
1224
1225        let faers = faersvd.S();
1226        // Our implementation allows negative singular value
1227        assert_diags_ulps_eq(
1228            &DiagonalMatrix {
1229                elements: s.elements.map(f64::abs),
1230            },
1231            &faers,
1232        );
1233    }
1234
1235    #[test]
1236    fn test_transpose_2x2() {
1237        let rows = [[1.0, -2.0], [3.0, 4.0]];
1238        let matrix = Matrix::<2, 2> { rows };
1239        let faer_matrix = fill_faer(rows);
1240        let custom_transpose = matrix.transpose();
1241        let faer_transpose = faer_matrix.transpose();
1242        assert_matrices_ulps_eq::<2, 2, _, _>(&custom_transpose, &faer_transpose);
1243    }
1244
1245    #[test]
1246    fn test_transpose_2x3() {
1247        let rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
1248        let matrix = Matrix::<2, 3> { rows };
1249        let faer_matrix = fill_faer(rows);
1250        let custom_transpose = matrix.transpose();
1251        let faer_transpose = faer_matrix.transpose();
1252        assert_matrices_ulps_eq::<3, 2, _, _>(&custom_transpose, &faer_transpose);
1253    }
1254
1255    #[test]
1256    fn test_transpose_3x2() {
1257        let rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
1258        let matrix = Matrix::<3, 2> { rows };
1259        let faer_matrix = fill_faer(rows);
1260        let custom_transpose = matrix.transpose();
1261        let faer_transpose = faer_matrix.transpose();
1262        assert_matrices_ulps_eq::<2, 3, _, _>(&custom_transpose, &faer_transpose);
1263    }
1264
1265    #[test]
1266    fn test_transpose_1x1() {
1267        let rows = [[-9.0]];
1268        let matrix = Matrix::<1, 1> { rows };
1269        assert_matrices_ulps_eq::<1, 1, _, _>(&matrix.transpose(), &matrix);
1270    }
1271
1272    #[test]
1273    fn test_general_matrix_methods() {
1274        // Matrix
1275        let zeros = Matrix::<2, 3>::zeros();
1276        let full = Matrix::<2, 3>::full(7.5);
1277        for i in 0..2 {
1278            for j in 0..3 {
1279                assert_eq!(zeros[(i, j)], 0.0);
1280                assert_eq!(full[(i, j)], 7.5);
1281            }
1282        }
1283    }
1284
1285    #[test]
1286    fn test_square_matrix_methods() {
1287        let identity = Matrix::<3, 3>::identity();
1288        let expected = Matrix::<3, 3> {
1289            rows: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
1290        };
1291        assert_matrices_ulps_eq::<3, 3, _, _>(&identity, &expected);
1292    }
1293
1294    #[test]
1295    fn test_diag_conversions() {
1296        let mat = Matrix::<3, 3> {
1297            rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
1298        };
1299        let diag = mat.diagonal();
1300        let expected_diag = DiagonalMatrix {
1301            elements: [1.0, 5.0, 9.0],
1302        };
1303        assert_diags_ulps_eq(&diag, &expected_diag);
1304
1305        let from_diag = diag.to_dense();
1306        let expected_from_diag = Matrix {
1307            rows: [[1.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 9.0]],
1308        };
1309        assert_matrices_ulps_eq::<3, 3, _, _>(&from_diag, &expected_from_diag);
1310    }
1311
1312    #[rstest(
1313        rows, vars,
1314        case(
1315            [[1.0, 2.0], [3.0, 4.0]],
1316            [0.5, 1.5]
1317        ),
1318        case(
1319            [[2.0, 0.0, 1.0], [3.0, 0.0, 0.0], [5.0, 1.0, 1.0]],
1320            [1.0, 2.0, 3.0]
1321        ),
1322        case(
1323            [[-33.0, 2.0, 0.0, 1.0], [3.0, -45.0, 0.0, 0.0], [5.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0]],
1324            [1.0, 2.0, 3.0, 4.0]
1325        ),
1326    )]
1327    fn test_quadratic_form<const N: usize>(rows: [[f64; N]; N], vars: [f64; N]) {
1328        let matrix = Matrix { rows };
1329        let result = matrix.compute_quadratic_form(&vars);
1330        assert_relative_eq!(
1331            result,
1332            (fill_faer_column(vars).transpose() * fill_faer(rows) * fill_faer_column(vars))[(0, 0)],
1333            max_relative = 1e-14
1334        );
1335    }
1336
1337    #[rstest(
1338        rows,
1339        case([[1.0, -2.0], [3.0, 4.0]]),
1340        case([[10.0, 0.0], [0.0, 0.1]]),
1341        case([[1.0, 1.0], [0.0, 1.0]]),
1342    )]
1343    fn test_inverse_2x2(rows: [[f64; 2]; 2]) {
1344        let matrix = Matrix22 { rows };
1345        let inv_matrix = matrix.inverse().expect("invertible");
1346        let product = matrix.matmul(&inv_matrix);
1347        let identity = Matrix22::identity();
1348
1349        assert_matrices_ulps_eq::<2, 2, _, _>(&product, &identity);
1350    }
1351    #[rstest(
1352        rows,
1353        case(Matrix33::identity().rows),
1354        case([[1.0, -3.0, 4.5], [3.0, 4.0,5.0], [8.0, -9.3, 10.0]]),
1355        case([[2.0, 1.0, 0.0], [0.0, 1.0, 2.0], [1.0, 0.0, 1.0]]),
1356        case([[5.0, -2.0, 3.0], [1.0, 0.0, 4.0], [-1.0, 2.0, 1.0]])
1357    )]
1358    fn test_inverse_3x3(rows: [[f64; 3]; 3]) {
1359        let matrix = Matrix33 { rows };
1360        let inv_matrix = matrix.inverse().expect("invertible");
1361        let product = matrix.matmul(&inv_matrix);
1362        let identity = Matrix33::identity();
1363
1364        assert_matrices_ulps_eq::<3, 3, _, _>(&product, &identity);
1365    }
1366    #[rstest(
1367        rows,
1368        case(Matrix44::identity().rows),
1369        case([[1.0, -4.0, 4.5,1.0], [4.0, 4.0,5.0,0.0], [8.0, -9.4, 10.0,9.0], [-1.0,-1.0,1.0,1.0]]),
1370    )]
1371    fn test_inverse_4x4(rows: [[f64; 4]; 4]) {
1372        let matrix = Matrix44 { rows };
1373        let inv_matrix = matrix.inverse().expect("invertible");
1374        let product = matrix.matmul(&inv_matrix);
1375        let identity = Matrix44::identity();
1376
1377        assert_matrices_ulps_eq::<4, 4, _, _>(&product, &identity);
1378    }
1379}