1use super::Matrix;
5use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
6
7impl<const N: usize, const M: usize> Index<(usize, usize)> for Matrix<N, M> {
8 type Output = f64;
9
10 #[inline]
25 fn index(&self, index: (usize, usize)) -> &f64 {
26 let (i, j) = index;
27 &self.rows[i][j]
28 }
29}
30impl<const N: usize, const M: usize> IndexMut<(usize, usize)> for Matrix<N, M> {
31 #[inline]
32 fn index_mut(&mut self, index: (usize, usize)) -> &mut f64 {
33 let (i, j) = index;
34 &mut self.rows[i][j]
35 }
36}
37
38impl<const N: usize, const M: usize> Mul<f64> for Matrix<N, M> {
39 type Output = Self;
40
41 #[inline]
42 fn mul(self, rhs: f64) -> Self {
53 self.map_elements(|x| x * rhs)
54 }
55}
56
57impl<const N: usize, const M: usize> Mul<Matrix<N, M>> for f64 {
58 type Output = Matrix<N, M>;
59
60 #[inline]
71 fn mul(self, rhs: Self::Output) -> Self::Output {
72 rhs.map_elements(|x| x * self)
73 }
74}
75
76impl<const N: usize, const M: usize> MulAssign<f64> for Matrix<N, M> {
77 #[inline]
78 fn mul_assign(&mut self, rhs: f64) {
90 self.iter_elements_mut().for_each(|x| *x *= rhs);
91 }
92}
93
94impl<const N: usize, const M: usize> Neg for Matrix<N, M> {
95 type Output = Self;
96
97 #[inline]
107 fn neg(self) -> Self {
108 self.map_elements(f64::neg)
109 }
110}
111
112impl<const N: usize, const M: usize> Add<Self> for Matrix<N, M> {
113 type Output = Self;
114
115 #[inline]
116 fn add(self, rhs: Self) -> Self {
117 Self {
118 rows: std::array::from_fn(|i| {
119 std::array::from_fn(|j| self.rows[i][j] + rhs.rows[i][j])
120 }),
121 }
122 }
123}
124impl<const N: usize, const M: usize> AddAssign for Matrix<N, M> {
125 #[inline]
126 fn add_assign(&mut self, rhs: Self) {
127 self.iter_elements_mut()
128 .zip(rhs.iter_elements())
129 .for_each(|(x, r)| *x += r);
130 }
131}
132impl<const N: usize, const M: usize> Sub<Self> for Matrix<N, M> {
133 type Output = Self;
134
135 #[inline]
136 fn sub(self, rhs: Self) -> Self {
137 Self {
138 rows: std::array::from_fn(|i| {
139 std::array::from_fn(|j| self.rows[i][j] - rhs.rows[i][j])
140 }),
141 }
142 }
143}
144impl<const N: usize, const M: usize> SubAssign for Matrix<N, M> {
145 #[inline]
146 fn sub_assign(&mut self, rhs: Self) {
147 self.iter_elements_mut()
148 .zip(rhs.iter_elements())
149 .for_each(|(x, r)| *x -= r);
150 }
151}
152
#[cfg(test)]
mod tests {
    use crate::{GeneralMatrix, matrix::Matrix};
    use rstest::rstest;

    /// `+` on two 2x2 matrices adds element by element.
    #[test]
    fn test_matrix_add_2x2() {
        let a_rows = [[1.0, 2.0], [3.0, 4.0]];
        let b_rows = [[5.0, 6.0], [7.0, 8.0]];

        let a = Matrix::<2, 2> { rows: a_rows };
        let b = Matrix::<2, 2> { rows: b_rows };
        let c = Matrix::<2, 2> {
            rows: [[6.0, 8.0], [10.0, 12.0]],
        };

        assert_eq!(a + b, c);
    }

    /// `+` also works for non-square (2x3) matrices.
    #[test]
    fn test_matrix_add_2x3() {
        let a_rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b_rows = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
        let a = Matrix::<2, 3> { rows: a_rows };
        let b = Matrix::<2, 3> { rows: b_rows };
        let c = Matrix::<2, 3> {
            rows: [[8.0, 10.0, 12.0], [14.0, 16.0, 18.0]],
        };

        assert_eq!(a + b, c);
    }

    /// `-` on two 2x2 matrices subtracts element by element.
    #[test]
    fn test_matrix_sub_2x2() {
        let a_rows = [[1.0, 2.0], [3.0, 4.0]];
        let b_rows = [[5.0, 6.0], [7.0, 8.0]];
        let a = Matrix::<2, 2> { rows: a_rows };
        let b = Matrix::<2, 2> { rows: b_rows };
        let c = Matrix::<2, 2> {
            rows: [[-4.0, -4.0], [-4.0, -4.0]],
        };
        assert_eq!(a - b, c);
    }

    /// `-` also works for non-square (2x3) matrices.
    #[test]
    fn test_matrix_sub_2x3() {
        let a_rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b_rows = [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
        let a = Matrix::<2, 3> { rows: a_rows };
        let b = Matrix::<2, 3> { rows: b_rows };
        let c = Matrix::<2, 3> {
            rows: [[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]],
        };
        assert_eq!(a - b, c);
    }

    /// Unary minus negates every element, including the all-zero matrix.
    #[rstest]
    #[case([[1.0, -2.0], [3.0, 4.0]])]
    #[case([[0.0, 0.0], [0.0, 0.0]])]
    fn test_matrix_neg_2x2(#[case] rows: [[f64; 2]; 2]) {
        let matrix = Matrix::<2, 2> { rows };
        let expected = Matrix {
            rows: rows.map(|row| row.map(|x| -x)),
        };
        assert_eq!(-matrix, expected);
    }

    /// `matrix * scalar` scales every element, for positive, negative and zero scalars.
    #[rstest]
    #[case([[1.0, 2.0], [3.0, 4.0]], 5.0)]
    #[case([[1.0, 2.0], [3.0, 4.0]], -1.0)]
    #[case([[1.0, 2.0], [3.0, 4.0]], 0.0)]
    fn test_matrix_scalar_mul_2x2(#[case] rows: [[f64; 2]; 2], #[case] scalar: f64) {
        let matrix = Matrix::<2, 2> { rows };
        let expected = Matrix {
            rows: rows.map(|row| row.map(|x| x * scalar)),
        };
        assert_eq!(matrix * scalar, expected);
    }

    /// Indexing with `(i, j)` reads `rows[i][j]`.
    #[test]
    fn test_indexing() {
        let mat = Matrix::<2, 3> {
            rows: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        };
        assert_eq!(mat[(0, 2)], 3.0);
        assert_eq!(mat[(1, 1)], 5.0);
    }

    /// Writing through `(i, j)` changes only the targeted element.
    #[test]
    fn test_mut_indexing() {
        let mut mat = Matrix::<2, 2>::zeros();
        mat[(0, 1)] = 99.0;
        mat[(1, 0)] = -5.5;
        assert_eq!(mat[(0, 1)], 99.0);
        assert_eq!(mat[(1, 0)], -5.5);
        // An element that was never written keeps its initial zero.
        assert_eq!(mat[(1, 1)], 0.0);
    }

    /// `a += b` must leave `a` equal to `a + b`, for several shapes.
    #[rstest]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    )]
    #[case(
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        [[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
    )]
    #[case(
        [[1.0], [2.0]],
        [[3.0], [4.0]],
    )]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    )]
    fn test_add_assign<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] b_rows: [[f64; M]; N],
    ) {
        let mut a = Matrix { rows: a_rows };
        let b = Matrix { rows: b_rows };
        // Reference result computed with the by-value operator.
        let c = a.clone() + b.clone();

        a += b;
        assert_eq!(a, c);
    }

    /// `a -= b` must leave `a` equal to `a - b`, for several shapes.
    #[rstest]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    )]
    #[case(
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        [[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
    )]
    #[case(
        [[1.0], [2.0]],
        [[3.0], [4.0]],
    )]
    #[case(
        [[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    )]
    fn test_sub_assign<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] b_rows: [[f64; M]; N],
    ) {
        let mut a = Matrix { rows: a_rows };
        let b = Matrix { rows: b_rows };
        // Reference result computed with the by-value operator.
        let c = a.clone() - b.clone();

        a -= b;
        assert_eq!(a, c);
    }

    /// `a *= x` must leave `a` equal to `a * x`, for several shapes and scalars.
    #[rstest]
    #[case([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 0.0)]
    #[case([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], -91.0)]
    #[case([[1.0], [2.0]], 33.3)]
    #[case([[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]], 84.0)]
    fn test_mul_assign<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] x: f64,
    ) {
        let mut a = Matrix { rows: a_rows };
        // Reference result computed with the by-value operator.
        let c = a.clone() * x;

        a *= x;
        assert_eq!(a, c);
    }

    /// `x * a` (scalar on the left) must agree with `a * x`.
    #[rstest]
    #[case([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 0.0)]
    #[case([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], -91.0)]
    #[case([[1.0], [2.0]], 33.3)]
    #[case([[1.0, 2.0], [3.0, 4.0], [1.0, 1.0]], 84.0)]
    fn test_mul_left<const M: usize, const N: usize>(
        #[case] a_rows: [[f64; M]; N],
        #[case] x: f64,
    ) {
        let a = Matrix { rows: a_rows };
        assert_eq!(a.clone() * x, x * a);
    }
}