hoomd_linear_algebra/matrix/
ops.rs

// Copyright (c) 2024-2026 The Regents of the University of Michigan.
// Part of hoomd-rs, released under the BSD 3-Clause License.
3
4use super::Matrix;
5use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
6
7impl<const N: usize, const M: usize> Index<(usize, usize)> for Matrix<N, M> {
8    type Output = f64;
9
10    /// Access matrix elements..
11    ///
12    /// Elements are indexed by `(row, column)`.
13    ///
14    /// # Examples
15    /// ```
16    /// use hoomd_linear_algebra::matrix::Matrix;
17    ///
18    /// let rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
19    /// let a = Matrix { rows };
20    /// assert_eq!(a[(0, 1)], rows[0][1]);
21    /// assert_eq!(a[(2, 1)], 6.0);
22    /// assert_eq!(a[(1, 1)], 4.0);
23    /// ```
24    #[inline]
25    fn index(&self, index: (usize, usize)) -> &f64 {
26        let (i, j) = index;
27        &self.rows[i][j]
28    }
29}
30impl<const N: usize, const M: usize> IndexMut<(usize, usize)> for Matrix<N, M> {
31    #[inline]
32    fn index_mut(&mut self, index: (usize, usize)) -> &mut f64 {
33        let (i, j) = index;
34        &mut self.rows[i][j]
35    }
36}
37
38impl<const N: usize, const M: usize> Mul<f64> for Matrix<N, M> {
39    type Output = Self;
40
41    #[inline]
42    /// Matrix-scalar multiplication.
43    ///
44    /// # Examples
45    /// ```
46    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
47    ///
48    /// let matrix = Matrix22::full(2.0);
49    /// let scalar = 2.0;
50    /// assert_eq!(matrix * scalar, matrix + matrix);
51    /// ```
52    fn mul(self, rhs: f64) -> Self {
53        self.map_elements(|x| x * rhs)
54    }
55}
56
57impl<const N: usize, const M: usize> Mul<Matrix<N, M>> for f64 {
58    type Output = Matrix<N, M>;
59
60    /// Matrix-scalar multiplication.
61    ///
62    /// # Examples
63    /// ```
64    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
65    ///
66    /// let matrix = Matrix22::full(2.0);
67    /// let scalar = 3.0;
68    /// assert_eq!(scalar * matrix, matrix * scalar);
69    /// ```
70    #[inline]
71    fn mul(self, rhs: Self::Output) -> Self::Output {
72        rhs.map_elements(|x| x * self)
73    }
74}
75
76impl<const N: usize, const M: usize> MulAssign<f64> for Matrix<N, M> {
77    #[inline]
78    /// Matrix-scalar multiplication assignment.
79    ///
80    /// # Examples
81    /// ```
82    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
83    ///
84    /// let mut matrix = Matrix22::full(2.0);
85    /// let matrix_copy = matrix.clone();
86    /// matrix *= 3.0;
87    /// assert_eq!(matrix, matrix_copy * 3.0);
88    /// ```
89    fn mul_assign(&mut self, rhs: f64) {
90        self.iter_elements_mut().for_each(|x| *x *= rhs);
91    }
92}
93
94impl<const N: usize, const M: usize> Neg for Matrix<N, M> {
95    type Output = Self;
96
97    /// Matrix negation.
98    ///
99    /// # Examples
100    /// ```
101    /// use hoomd_linear_algebra::{Full, GeneralMatrix, matrix::Matrix22};
102    ///
103    /// let matrix = Matrix22::full(5.0);
104    /// assert_eq!(-matrix, Matrix22::zeros() - matrix);
105    /// ```
106    #[inline]
107    fn neg(self) -> Self {
108        self.map_elements(f64::neg)
109    }
110}
111
112impl<const N: usize, const M: usize> Add<Self> for Matrix<N, M> {
113    type Output = Self;
114
115    #[inline]
116    fn add(self, rhs: Self) -> Self {
117        Self {
118            rows: std::array::from_fn(|i| {
119                std::array::from_fn(|j| self.rows[i][j] + rhs.rows[i][j])
120            }),
121        }
122    }
123}
124impl<const N: usize, const M: usize> AddAssign for Matrix<N, M> {
125    #[inline]
126    fn add_assign(&mut self, rhs: Self) {
127        self.iter_elements_mut()
128            .zip(rhs.iter_elements())
129            .for_each(|(x, r)| *x += r);
130    }
131}
132impl<const N: usize, const M: usize> Sub<Self> for Matrix<N, M> {
133    type Output = Self;
134
135    #[inline]
136    fn sub(self, rhs: Self) -> Self {
137        Self {
138            rows: std::array::from_fn(|i| {
139                std::array::from_fn(|j| self.rows[i][j] - rhs.rows[i][j])
140            }),
141        }
142    }
143}
144impl<const N: usize, const M: usize> SubAssign for Matrix<N, M> {
145    #[inline]
146    fn sub_assign(&mut self, rhs: Self) {
147        self.iter_elements_mut()
148            .zip(rhs.iter_elements())
149            .for_each(|(x, r)| *x -= r);
150    }
151}
152
#[cfg(test)]
mod tests {
    // Unit tests for the indexing and arithmetic operator impls in this module.
    use crate::{GeneralMatrix, matrix::Matrix};
    use rstest::rstest;

    #[test]
    fn test_matrix_add_2x2() {
        let a_rows = [[1.0, 2.0], [3.0, 4.0]];
        let b_rows = [[5.0, 6.0], [7.0, 8.0]];

        let a = Matrix::<2, 2> { rows: a_rows };
        let b = Matrix::<2, 2> { rows: b_rows };
        let c = Matrix::<2, 2> {
            rows: [[6.0, 8.0], [10.0, 12.0]],
        };

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_matrix_add_2x3() {
        let a_rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b_rows = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
        let a = Matrix::<2, 3> { rows: a_rows };
        let b = Matrix::<2, 3> { rows: b_rows };
        let c = Matrix::<2, 3> {
            rows: [[8.0, 10.0, 12.0], [14.0, 16.0, 18.0]],
        };

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_matrix_sub_2x2() {
        let a_rows = [[1.0, 2.0], [3.0, 4.0]];
        let b_rows = [[5.0, 6.0], [7.0, 8.0]];
        let a = Matrix::<2, 2> { rows: a_rows };
        let b = Matrix::<2, 2> { rows: b_rows };
        let c = Matrix::<2, 2> {
            rows: [[-4.0, -4.0], [-4.0, -4.0]],
        };
        assert_eq!(a - b, c);
    }

    #[test]
    fn test_matrix_sub_2x3() {
        let a_rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b_rows = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
        let a = Matrix::<2, 3> { rows: a_rows };
        let b = Matrix::<2, 3> { rows: b_rows };
        let c = Matrix::<2, 3> {
            rows: [[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]],
        };
        assert_eq!(a - b, c);
    }

    // Uses the same `#[rstest]` + `#[case]` attribute form as every other
    // parameterized test in this module (rather than the compact
    // `#[rstest(arg, case(...))]` form).
    #[rstest]
    #[case([[1.0, -2.0], [3.0, 4.0]])]
    #[case([[0.0, 0.0], [0.0, 0.0]])]
    fn test_matrix_neg_2x2(#[case] rows: [[f64; 2]; 2]) {
        let matrix = Matrix::<2, 2> { rows };
        let expected = Matrix {
            rows: rows.map(|row| row.map(|x| -x)),
        };
        assert_eq!(-matrix, expected);
    }

    #[rstest]
    #[case([[1.0, 2.0], [3.0, 4.0]], 5.0)]
    #[case([[1.0, 2.0], [3.0, 4.0]], -1.0)]
    #[case([[1.0, 2.0], [3.0, 4.0]], 0.0)]
    fn test_matrix_scalar_mul_2x2(#[case] rows: [[f64; 2]; 2], #[case] scalar: f64) {
        let matrix = Matrix::<2, 2> { rows };
        let expected = Matrix {
            rows: rows.map(|row| row.map(|x| x * scalar)),
        };
        assert_eq!(matrix * scalar, expected);
    }

    #[test]
    fn test_indexing() {
        let mat = Matrix::<2, 3> {
            rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        };
        assert_eq!(mat[(0, 2)], 3.0);
        assert_eq!(mat[(1, 1)], 5.0);
    }

    // Indexing delegates to array indexing, so a row past the end must panic.
    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_indexing_out_of_bounds() {
        let mat = Matrix::<2, 2>::zeros();
        let _out_of_bounds: f64 = mat[(2, 0)];
    }

    #[test]
    fn test_mut_indexing() {
        let mut mat = Matrix::<2, 2>::zeros();
        mat[(0, 1)] = 99.0;
        mat[(1, 0)] = -5.5;
        assert_eq!(mat[(0, 1)], 99.0);
        assert_eq!(mat[(1, 0)], -5.5);
        assert_eq!(mat[(1, 1)], 0.0);
    }

    #[rstest]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    )]
    #[case(
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        [[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
    )]
    #[case(
        [[1.0], [2.0]],
        [[3.0], [4.0]],
    )]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    )]
    fn test_add_assign<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] b_rows: [[f64; M]; N],
    ) {
        let mut a = Matrix { rows: a_rows };
        let b = Matrix { rows: b_rows };
        let c = a.clone() + b.clone();

        a += b;
        assert_eq!(a, c);
    }

    #[rstest]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    )]
    #[case(
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        [[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
    )]
    #[case(
        [[1.0], [2.0]],
        [[3.0], [4.0]],
    )]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    )]
    fn test_sub_assign<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] b_rows: [[f64; M]; N],
    ) {
        let mut a = Matrix { rows: a_rows };
        let b = Matrix { rows: b_rows };
        let c = a.clone() - b.clone();

        a -= b;
        assert_eq!(a, c);
    }

    #[rstest]
    #[case([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 0.0)]
    #[case([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], -91.0)]
    #[case([[1.0], [2.0]], 33.3)]
    #[case([[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]], 84.0)]
    fn test_mul_assign<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] x: f64,
    ) {
        let mut a = Matrix { rows: a_rows };
        let c = a.clone() * x;

        a *= x;
        assert_eq!(a, c);
    }

    #[rstest]
    #[case([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 0.0)]
    #[case([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], -91.0)]
    #[case([[1.0], [2.0]], 33.3)]
    #[case([[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]], 84.0)]
    fn test_mul_left<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] x: f64,
    ) {
        let a = Matrix { rows: a_rows };
        assert_eq!(a.clone() * x, x * a);
    }
}