1use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use std::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
7
8use crate::{GeneralMatrix, matrix::Matrix};
9
10#[serde_as]
28#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
29pub struct DiagonalMatrix<const N: usize> {
30 #[serde_as(as = "[_; N]")]
32 pub elements: [f64; N],
33}
34
35impl<const N: usize> Index<usize> for DiagonalMatrix<N> {
36 type Output = f64;
37
38 #[inline]
51 fn index(&self, index: usize) -> &f64 {
52 &self.elements[index]
53 }
54}
55impl<const N: usize> IndexMut<usize> for DiagonalMatrix<N> {
56 #[inline]
57 fn index_mut(&mut self, index: usize) -> &mut f64 {
58 &mut self.elements[index]
59 }
60}
61
62impl<const N: usize> Index<(usize, usize)> for DiagonalMatrix<N> {
63 type Output = f64;
64
65 #[inline]
79 fn index(&self, index: (usize, usize)) -> &f64 {
80 let (i, j) = index;
81 if i == j { &self.elements[i] } else { &0.0 }
82 }
83}
84
85impl<const N: usize> GeneralMatrix for DiagonalMatrix<N> {
86 #[inline]
87 fn zeros() -> Self {
88 Self {
89 elements: std::array::from_fn(|_| 0.0),
90 }
91 }
92 #[inline]
93 fn shape(&self) -> (usize, usize) {
94 (N, N)
95 }
96}
97
98impl<const N: usize> DiagonalMatrix<N> {
99 #[must_use]
112 #[inline]
113 pub fn to_dense(self) -> Matrix<N, N> {
114 Matrix {
115 rows: std::array::from_fn(|i| {
116 std::array::from_fn(|j| if i == j { self.elements[i] } else { 0.0 })
117 }),
118 }
119 }
120}
121
122impl<const N: usize> Mul<f64> for DiagonalMatrix<N> {
123 type Output = Self;
124
125 #[inline]
140 fn mul(self, rhs: f64) -> Self {
141 Self {
142 elements: self.elements.map(|r| r * rhs),
143 }
144 }
145}
146
147impl<const N: usize> Neg for DiagonalMatrix<N> {
148 type Output = Self;
149
150 #[inline]
165 fn neg(self) -> Self {
166 Self {
167 elements: self.elements.map(|r| -r),
168 }
169 }
170}
171
172impl<const N: usize> Add<Self> for DiagonalMatrix<N> {
173 type Output = Self;
174
175 #[inline]
193 fn add(self, rhs: Self) -> Self {
194 Self {
195 elements: std::array::from_fn(|i| self[i] + rhs[i]),
196 }
197 }
198}
199impl<const N: usize> Sub<Self> for DiagonalMatrix<N> {
200 type Output = Self;
201
202 #[inline]
220 fn sub(self, rhs: Self) -> Self {
221 Self {
222 elements: std::array::from_fn(|i| self[i] - rhs[i]),
223 }
224 }
225}
226
227impl<const N: usize> AddAssign for DiagonalMatrix<N> {
228 #[inline]
229 fn add_assign(&mut self, rhs: Self) {
230 self.elements
231 .iter_mut()
232 .zip(rhs.elements.iter())
233 .for_each(|(x, r)| *x += r);
234 }
235}
236
237impl<const N: usize> MulAssign<f64> for DiagonalMatrix<N> {
238 #[inline]
239 fn mul_assign(&mut self, rhs: f64) {
240 self.elements.iter_mut().for_each(|x| *x *= rhs);
241 }
242}
243
244impl<const N: usize> SubAssign for DiagonalMatrix<N> {
245 #[inline]
246 fn sub_assign(&mut self, rhs: Self) {
247 self.elements
248 .iter_mut()
249 .zip(rhs.elements.iter())
250 .for_each(|(x, r)| *x -= r);
251 }
252}
253
#[cfg(test)]
mod tests {
    use approxim::assert_ulps_eq;
    use rstest::rstest;
    use std::ops::Index;

    use super::*;
    use crate::GeneralMatrix;

    /// Asserts that the `N` diagonal values of `m0` match `m1[0]..m1[N - 1]`.
    ///
    /// Comparison uses `assert_ulps_eq!` with `epsilon = 1e-13`. `m1` is any
    /// `usize`-indexable source of `f64`s (e.g. a `Vec<f64>` of expected
    /// values) and must hold at least `N` entries.
    fn assert_diags_ulps_eq<const N: usize>(
        m0: &DiagonalMatrix<N>,
        m1: &impl Index<usize, Output = f64>,
    ) {
        for i in 0..N {
            assert_ulps_eq!(m0[i], m1[i], epsilon = 1e-13);
        }
    }

    #[test]
    fn test_diagonal_matrix_add_n2() {
        let a_diag = [1.0, 2.0];
        let b_diag = [3.0, 4.0];
        let a = DiagonalMatrix { elements: a_diag };
        let b = DiagonalMatrix { elements: b_diag };
        let expected: Vec<f64> = a_diag
            .iter()
            .zip(b_diag.iter())
            .map(|(x, y)| x + y)
            .collect();
        let custom_sum = a + b;
        assert_diags_ulps_eq(&custom_sum, &expected);
    }

    #[test]
    fn test_diagonal_matrix_add_n3() {
        let a_diag = [1.0, 2.0, 3.0];
        let b_diag = [4.0, 5.0, 6.0];
        let a = DiagonalMatrix { elements: a_diag };
        let b = DiagonalMatrix { elements: b_diag };
        let expected: Vec<f64> = a_diag
            .iter()
            .zip(b_diag.iter())
            .map(|(x, y)| x + y)
            .collect();
        let custom_sum = a + b;
        assert_diags_ulps_eq(&custom_sum, &expected);
    }

    #[test]
    fn test_diagonal_matrix_add_assign_n3() {
        let a_diag = [1.0, 2.0, 3.0];
        let b_diag = [4.0, 5.0, 6.0];
        let mut a = DiagonalMatrix { elements: a_diag };
        let b = DiagonalMatrix { elements: b_diag };
        let expected: Vec<f64> = a_diag
            .iter()
            .zip(b_diag.iter())
            .map(|(x, y)| x + y)
            .collect();
        a += b;
        assert_diags_ulps_eq(&a, &expected);
    }

    #[test]
    fn test_diagonal_matrix_sub_n2() {
        let a_diag = [1.0, 2.0];
        let b_diag = [3.0, 4.0];
        let a = DiagonalMatrix { elements: a_diag };
        let b = DiagonalMatrix { elements: b_diag };
        let expected: Vec<f64> = a_diag
            .iter()
            .zip(b_diag.iter())
            .map(|(x, y)| x - y)
            .collect();
        let custom_sub = a - b;
        assert_diags_ulps_eq(&custom_sub, &expected);
    }

    #[test]
    fn test_diagonal_matrix_sub_n3() {
        let a_diag = [1.0, 2.0, 3.0];
        let b_diag = [4.0, 5.0, 6.0];
        let a = DiagonalMatrix { elements: a_diag };
        let b = DiagonalMatrix { elements: b_diag };
        let expected: Vec<f64> = a_diag
            .iter()
            .zip(b_diag.iter())
            .map(|(x, y)| x - y)
            .collect();
        let custom_sub = a - b;
        assert_diags_ulps_eq(&custom_sub, &expected);
    }

    #[test]
    fn test_diagonal_matrix_sub_assign_n3() {
        let a_diag = [1.0, 2.0, 3.0];
        let b_diag = [4.0, 5.0, 6.0];
        let mut a = DiagonalMatrix { elements: a_diag };
        let b = DiagonalMatrix { elements: b_diag };
        let expected: Vec<f64> = a_diag
            .iter()
            .zip(b_diag.iter())
            .map(|(x, y)| x - y)
            .collect();
        a -= b;
        assert_diags_ulps_eq(&a, &expected);
    }

    #[test]
    fn test_diagonal_matrix_neg_n2() {
        let diag = [1.0, -2.0];
        let matrix = DiagonalMatrix { elements: diag };
        let expected: Vec<f64> = diag.iter().map(|x| -x).collect();
        let custom_neg = -matrix;
        assert_diags_ulps_eq(&custom_neg, &expected);
    }

    #[test]
    fn test_diagonal_matrix_neg_n3() {
        let diag = [1.0, -2.0, 0.0];
        let matrix = DiagonalMatrix { elements: diag };
        let expected: Vec<f64> = diag.iter().map(|x| -x).collect();
        let custom_neg = -matrix;
        assert_diags_ulps_eq(&custom_neg, &expected);
    }

    #[rstest]
    #[case([1.0, 2.0], 5.0)]
    #[case([1.0, 2.0], -1.0)]
    #[case([1.0, 2.0], 0.0)]
    fn test_diagonal_matrix_scalar_mul_n2(#[case] diag: [f64; 2], #[case] scalar: f64) {
        let matrix = DiagonalMatrix { elements: diag };
        let expected: Vec<f64> = diag.iter().map(|x| x * scalar).collect();
        let custom_mul = matrix * scalar;
        assert_diags_ulps_eq(&custom_mul, &expected);
    }

    #[rstest]
    #[case([1.0, 2.0], 5.0)]
    #[case([1.0, 2.0], -1.0)]
    #[case([1.0, 2.0], 0.0)]
    fn test_diagonal_matrix_scalar_mul_assign_n2(#[case] diag: [f64; 2], #[case] scalar: f64) {
        let mut matrix = DiagonalMatrix { elements: diag };
        let expected: Vec<f64> = diag.iter().map(|x| x * scalar).collect();
        matrix *= scalar;
        assert_diags_ulps_eq(&matrix, &expected);
    }

    #[test]
    fn test_indexing() {
        let diag_mat = DiagonalMatrix::<3> {
            elements: [1.0, 2.0, 3.0],
        };
        assert_eq!(diag_mat[1], 2.0);
        assert_eq!(diag_mat[(2, 2)], 3.0);
        assert_eq!(diag_mat[(0, 1)], 0.0);
    }

    #[test]
    fn test_index_mut() {
        let mut diag_mat = DiagonalMatrix::<3> {
            elements: [1.0, 2.0, 3.0],
        };
        diag_mat[0] = -4.0;
        diag_mat[2] += 1.5;
        assert_eq!(diag_mat.elements, [-4.0, 2.0, 4.5]);
    }

    #[test]
    fn test_general_matrix_methods() {
        let diag_zeros = DiagonalMatrix::<4>::zeros();
        for i in 0..4 {
            assert_eq!(diag_zeros[i], 0.0);
        }
        assert_eq!(diag_zeros.shape(), (4, 4));
    }

    #[test]
    fn test_to_dense() {
        let diag = DiagonalMatrix::<3> {
            elements: [1.0, -2.0, 3.0],
        };
        let dense = diag.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { diag.elements[i] } else { 0.0 };
                assert_eq!(dense.rows[i][j], expected);
            }
        }
    }
}