hoomd_linear_algebra/
diagonal.rs

1// Copyright (c) 2024-2026 The Regents of the University of Michigan.
2// Part of hoomd-rs, released under the BSD 3-Clause License.
3
4use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
7
8use crate::{GeneralMatrix, matrix::Matrix};
9
/// A square, diagonal matrix with N rows and N columns.
///
/// Only the `N` diagonal entries are stored: `matrix.elements[i]` is the
/// matrix element $` A_{ii} `$.
/// All off-diagonal elements are 0.
///
/// # Example
/// ```
/// use hoomd_linear_algebra::matrix::DiagonalMatrix;
/// let a = DiagonalMatrix {
///     elements: [-2.0, 3.0],
/// };
///
/// assert_eq!(a[(0, 0)], -2.0);
/// assert_eq!(a[(0, 1)], 0.0);
/// assert_eq!(a[(1, 0)], 0.0);
/// assert_eq!(a[(1, 1)], 3.0);
/// ```
#[serde_as]
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub struct DiagonalMatrix<const N: usize> {
    /// The members of the diagonal of the matrix, where entry `i` holds
    /// $` A_{ii} `$.
    // `serde_with` supplies the (de)serialization for an array whose length is
    // the const generic `N`.
    #[serde_as(as = "[_; N]")]
    pub elements: [f64; N],
}
34
35impl<const N: usize> Index<usize> for DiagonalMatrix<N> {
36    type Output = f64;
37
38    /// Index the diagonal components.
39    ///
40    /// # Example
41    /// ```
42    /// use hoomd_linear_algebra::{SquareMatrix, matrix::DiagonalMatrix};
43    /// let a = DiagonalMatrix {
44    ///     elements: [1.0, 2.0, 3.0],
45    /// };
46    /// assert_eq!(a[0], 1.0);
47    /// assert_eq!(a[1], 2.0);
48    /// assert_eq!(a[2], 3.0);
49    /// ```
50    #[inline]
51    fn index(&self, index: usize) -> &f64 {
52        &self.elements[index]
53    }
54}
55impl<const N: usize> IndexMut<usize> for DiagonalMatrix<N> {
56    #[inline]
57    fn index_mut(&mut self, index: usize) -> &mut f64 {
58        &mut self.elements[index]
59    }
60}
61
62impl<const N: usize> Index<(usize, usize)> for DiagonalMatrix<N> {
63    type Output = f64;
64
65    /// Index matrix elements by (row, column).
66    ///
67    /// # Example
68    /// ```
69    /// use hoomd_linear_algebra::matrix::DiagonalMatrix;
70    /// let a = DiagonalMatrix {
71    ///     elements: [1.0, 2.0, 3.0],
72    /// };
73    /// assert_eq!(a[(0, 0)], 1.0);
74    /// assert_eq!(a[(1, 1)], 2.0);
75    /// assert_eq!(a[(0, 2)], 0.0);
76    /// assert_eq!(a[(2, 2)], 3.0);
77    /// ```
78    #[inline]
79    fn index(&self, index: (usize, usize)) -> &f64 {
80        let (i, j) = index;
81        if i == j { &self.elements[i] } else { &0.0 }
82    }
83}
84
85impl<const N: usize> GeneralMatrix for DiagonalMatrix<N> {
86    #[inline]
87    fn zeros() -> Self {
88        Self {
89            elements: std::array::from_fn(|_| 0.0),
90        }
91    }
92    #[inline]
93    fn shape(&self) -> (usize, usize) {
94        (N, N)
95    }
96}
97
98impl<const N: usize> DiagonalMatrix<N> {
99    /// Construct a dense matrix with the given diagonal.
100    ///
101    /// # Example
102    /// ```
103    /// use hoomd_linear_algebra::matrix::DiagonalMatrix;
104    ///
105    /// let a = DiagonalMatrix {
106    ///     elements: [-2.0, 3.0],
107    /// };
108    /// let b = a.to_dense();
109    /// assert_eq!(b.rows, [[-2.0, 0.0], [0.0, 3.0]]);
110    /// ```
111    #[must_use]
112    #[inline]
113    pub fn to_dense(self) -> Matrix<N, N> {
114        Matrix {
115            rows: std::array::from_fn(|i| {
116                std::array::from_fn(|j| if i == j { self.elements[i] } else { 0.0 })
117            }),
118        }
119    }
120}
121
122impl<const N: usize> Mul<f64> for DiagonalMatrix<N> {
123    type Output = Self;
124
125    /// Multiply a diagonal matrix by a scalar.
126    ///
127    /// # Example
128    /// ```
129    /// use hoomd_linear_algebra::matrix::DiagonalMatrix;
130    /// let a = DiagonalMatrix {
131    ///     elements: [-3.0, 2.0, -8.0],
132    /// };
133    ///
134    /// let b = a * 3.0;
135    /// assert_eq!(b[0], -9.0);
136    /// assert_eq!(b[1], 6.0);
137    /// assert_eq!(b[2], -24.0);
138    /// ```
139    #[inline]
140    fn mul(self, rhs: f64) -> Self {
141        Self {
142            elements: self.elements.map(|r| r * rhs),
143        }
144    }
145}
146
147impl<const N: usize> Neg for DiagonalMatrix<N> {
148    type Output = Self;
149
150    /// Negate a diagonal matrix.
151    ///
152    /// # Example
153    /// ```
154    /// use hoomd_linear_algebra::matrix::DiagonalMatrix;
155    /// let a = DiagonalMatrix {
156    ///     elements: [-3.0, 2.0, -8.0],
157    /// };
158    ///
159    /// let b = -a;
160    /// assert_eq!(b[0], 3.0);
161    /// assert_eq!(b[1], -2.0);
162    /// assert_eq!(b[2], 8.0);
163    /// ```
164    #[inline]
165    fn neg(self) -> Self {
166        Self {
167            elements: self.elements.map(|r| -r),
168        }
169    }
170}
171
172impl<const N: usize> Add<Self> for DiagonalMatrix<N> {
173    type Output = Self;
174
175    /// Add two diagonal matrices.
176    ///
177    /// # Example
178    /// ```
179    /// use hoomd_linear_algebra::matrix::DiagonalMatrix;
180    /// let a = DiagonalMatrix {
181    ///     elements: [-3.0, 2.0, -8.0],
182    /// };
183    /// let b = DiagonalMatrix {
184    ///     elements: [4.0, -4.0, 12.0],
185    /// };
186    ///
187    /// let c = a + b;
188    /// assert_eq!(c[0], 1.0);
189    /// assert_eq!(c[1], -2.0);
190    /// assert_eq!(c[2], 4.0);
191    /// ```
192    #[inline]
193    fn add(self, rhs: Self) -> Self {
194        Self {
195            elements: std::array::from_fn(|i| self[i] + rhs[i]),
196        }
197    }
198}
199impl<const N: usize> Sub<Self> for DiagonalMatrix<N> {
200    type Output = Self;
201
202    /// Subtract two diagonal matrices.
203    ///
204    /// # Example
205    /// ```
206    /// use hoomd_linear_algebra::matrix::DiagonalMatrix;
207    /// let a = DiagonalMatrix {
208    ///     elements: [-3.0, 2.0, -8.0],
209    /// };
210    /// let b = DiagonalMatrix {
211    ///     elements: [4.0, -4.0, 12.0],
212    /// };
213    ///
214    /// let c = a - b;
215    /// assert_eq!(c[0], -7.0);
216    /// assert_eq!(c[1], 6.0);
217    /// assert_eq!(c[2], -20.0);
218    /// ```
219    #[inline]
220    fn sub(self, rhs: Self) -> Self {
221        Self {
222            elements: std::array::from_fn(|i| self[i] - rhs[i]),
223        }
224    }
225}
226
227impl<const N: usize> AddAssign for DiagonalMatrix<N> {
228    #[inline]
229    fn add_assign(&mut self, rhs: Self) {
230        self.elements
231            .iter_mut()
232            .zip(rhs.elements.iter())
233            .for_each(|(x, r)| *x += r);
234    }
235}
236
237impl<const N: usize> MulAssign<f64> for DiagonalMatrix<N> {
238    #[inline]
239    fn mul_assign(&mut self, rhs: f64) {
240        self.elements.iter_mut().for_each(|x| *x *= rhs);
241    }
242}
243
244impl<const N: usize> SubAssign for DiagonalMatrix<N> {
245    #[inline]
246    fn sub_assign(&mut self, rhs: Self) {
247        self.elements
248            .iter_mut()
249            .zip(rhs.elements.iter())
250            .for_each(|(x, r)| *x -= r);
251    }
252}
253
#[cfg(test)]
mod tests {
    use approxim::assert_ulps_eq;
    use rstest::rstest;
    use std::ops::Index;

    use super::*;
    use crate::GeneralMatrix;

    /// Check that the `N` diagonal elements of `m0` match the first `N`
    /// entries of `m1`.
    fn assert_diags_ulps_eq<const N: usize>(
        m0: &DiagonalMatrix<N>,
        m1: &impl Index<usize, Output = f64>,
    ) {
        (0..N).for_each(|i| {
            assert_ulps_eq!(m0[i], m1[i], epsilon = 1e-13);
        });
    }

    /// Combine two diagonals entry by entry to produce the reference values.
    fn zip_with<const N: usize>(
        lhs: [f64; N],
        rhs: [f64; N],
        op: impl Fn(f64, f64) -> f64,
    ) -> Vec<f64> {
        lhs.into_iter().zip(rhs).map(|(x, y)| op(x, y)).collect()
    }

    #[test]
    fn test_diagonal_matrix_add_n2() {
        let (lhs, rhs) = ([1.0, 2.0], [3.0, 4.0]);
        let reference = zip_with(lhs, rhs, |x, y| x + y);

        let total = DiagonalMatrix { elements: lhs } + DiagonalMatrix { elements: rhs };
        assert_diags_ulps_eq(&total, &reference);
    }

    #[test]
    fn test_diagonal_matrix_add_n3() {
        let (lhs, rhs) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let reference = zip_with(lhs, rhs, |x, y| x + y);

        let total = DiagonalMatrix { elements: lhs } + DiagonalMatrix { elements: rhs };
        assert_diags_ulps_eq(&total, &reference);
    }

    #[test]
    fn test_diagonal_matrix_add_assign_n3() {
        let (lhs, rhs) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let reference = zip_with(lhs, rhs, |x, y| x + y);

        let mut accumulator = DiagonalMatrix { elements: lhs };
        accumulator += DiagonalMatrix { elements: rhs };
        assert_diags_ulps_eq(&accumulator, &reference);
    }

    #[test]
    fn test_diagonal_matrix_sub_n2() {
        let (lhs, rhs) = ([1.0, 2.0], [3.0, 4.0]);
        let reference = zip_with(lhs, rhs, |x, y| x - y);

        let difference = DiagonalMatrix { elements: lhs } - DiagonalMatrix { elements: rhs };
        assert_diags_ulps_eq(&difference, &reference);
    }

    #[test]
    fn test_diagonal_matrix_sub_n3() {
        let (lhs, rhs) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let reference = zip_with(lhs, rhs, |x, y| x - y);

        let difference = DiagonalMatrix { elements: lhs } - DiagonalMatrix { elements: rhs };
        assert_diags_ulps_eq(&difference, &reference);
    }

    #[test]
    fn test_diagonal_matrix_sub_assign_n3() {
        let (lhs, rhs) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let reference = zip_with(lhs, rhs, |x, y| x - y);

        let mut accumulator = DiagonalMatrix { elements: lhs };
        accumulator -= DiagonalMatrix { elements: rhs };
        assert_diags_ulps_eq(&accumulator, &reference);
    }

    #[test]
    fn test_diagonal_matrix_neg_n2() {
        let values = [1.0, -2.0];
        let reference: Vec<f64> = values.into_iter().map(|x: f64| -x).collect();

        let negated = -DiagonalMatrix { elements: values };
        assert_diags_ulps_eq(&negated, &reference);
    }

    #[test]
    fn test_diagonal_matrix_neg_n3() {
        let values = [1.0, -2.0, 0.0];
        let reference: Vec<f64> = values.into_iter().map(|x: f64| -x).collect();

        let negated = -DiagonalMatrix { elements: values };
        assert_diags_ulps_eq(&negated, &reference);
    }

    #[rstest]
    #[case([1.0, 2.0], 5.0)]
    #[case([1.0, 2.0], -1.0)]
    #[case([1.0, 2.0], 0.0)]
    fn test_diagonal_matrix_scalar_mul_n2(#[case] diag: [f64; 2], #[case] scalar: f64) {
        let reference: Vec<f64> = diag.into_iter().map(|x| x * scalar).collect();

        let scaled = DiagonalMatrix { elements: diag } * scalar;
        assert_diags_ulps_eq(&scaled, &reference);
    }

    #[rstest]
    #[case([1.0, 2.0], 5.0)]
    #[case([1.0, 2.0], -1.0)]
    #[case([1.0, 2.0], 0.0)]
    fn test_diagonal_matrix_scalar_mul_assign_n2(#[case] diag: [f64; 2], #[case] scalar: f64) {
        let reference: Vec<f64> = diag.into_iter().map(|x| x * scalar).collect();

        let mut scaled = DiagonalMatrix { elements: diag };
        scaled *= scalar;
        assert_diags_ulps_eq(&scaled, &reference);
    }

    #[test]
    fn test_indexing() {
        let m = DiagonalMatrix::<3> {
            elements: [1.0, 2.0, 3.0],
        };
        // A single index addresses the diagonal directly.
        assert_eq!(m[1], 2.0);
        // A (row, column) pair on the diagonal reads the stored element.
        assert_eq!(m[(2, 2)], 3.0);
        // A (row, column) pair off the diagonal reads as zero.
        assert_eq!(m[(0, 1)], 0.0);
    }

    #[test]
    fn test_general_matrix_methods() {
        let all_zero = DiagonalMatrix::<4>::zeros();
        (0..4).for_each(|i| assert_eq!(all_zero[i], 0.0));
    }
}