1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::fmt::Debug;
use std::ops::{Add, Sub, Mul, Div, AddAssign, SubAssign, MulAssign, DivAssign};
use num::Zero;

/// A trait for types that act as scalars in mathematical operations, like real numbers or complex numbers.
/// I concider a scalar, any struct which contains all the traits listed bellow.
/// f32 is an example if the Scalar implementation.
/// In the future a Complex Number struct implementation could potentially fit and allign with all the traits listed bellow.
pub trait Scalar: Copy + Clone + Debug + Zero
    + Add<Output = Self> 
    + Sub<Output = Self> 
    + Mul<Output = Self> 
    + Div<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + PartialOrd
    + PartialEq
{
    /// Fused Multiply-Add (FMA): Default implementation that computes `a * b + c`.
    fn fma(a: Self, b: Self, c: Self) -> Self {
        a * b + c
    }


    /// Fused Multiply-Sub (FMS): Default implementation that computes `a * b - c`.
    fn fms(a: Self, b: Self, c: Self) -> Self {
        a * b - c
    }

    fn to_f32(self) -> f32;

    fn from_f32(value: f32) -> Self;
}

// Specialize for f32 using SIMD FMA intrinsics
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
    _mm_fmadd_ps, _mm_set1_ps, _mm_cvtss_f32
};

impl Scalar for f32 {
    fn fma(a: f32, b: f32, c: f32) -> f32 {
        unsafe {
            // Use SIMD registers to do FMA: a * b + c
            let vec_a = _mm_set1_ps(a);
            let vec_b = _mm_set1_ps(b);
            let vec_c = _mm_set1_ps(c);
            let result = _mm_fmadd_ps(vec_a, vec_b, vec_c);
            _mm_cvtss_f32(result) // Return the first element of the SIMD result
        }
    }

    fn fms(a: f32, b: f32, c: f32) -> f32 {
        unsafe {
            // Use SIMD registers to do FMA: a * b + c
            let vec_a = _mm_set1_ps(a);
            let vec_b = _mm_set1_ps(b);
            let vec_c = _mm_set1_ps(-c);
            let result = _mm_fmadd_ps(vec_a, vec_b, vec_c);
            _mm_cvtss_f32(result) // Return the first element of the SIMD result
        }
    }

    fn from_f32(value: f32) -> Self {
        value
    }

    fn to_f32(self) -> f32 {
        self as f32
    }
}

// Default behavior (no specialization) for `f64` `i32`, `i64`, `u32`, `u64`, etc.
impl Scalar for f64 {
    fn from_f32(value: f32) -> Self {
        value as f64
    }

    fn to_f32(self) -> f32 {
        self as f32
    }
}

impl Scalar for i32 {
    fn from_f32(value: f32) -> Self {
        value as i32
    }

    fn to_f32(self) -> f32 {
        self as f32
    }
}

impl Scalar for i64 {
    fn from_f32(value: f32) -> Self {
        value as i64
    }

    fn to_f32(self) -> f32 {
        self as f32
    }
}

impl Scalar for u32 {
    fn from_f32(value: f32) -> Self {
        value as u32
    }

    fn to_f32(self) -> f32 {
        self as f32
    }
}

impl Scalar for u64 {
    fn from_f32(value: f32) -> Self {
        value as u64
    }

    fn to_f32(self) -> f32 {
        self as f32
    }
}