/*
* Copyright (c) 2020 Frank Fischer <frank-fischer@shadow-soft.de>
*
* This program is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>
*/
//! BLAS-like low-level vector routines.
#[cfg(feature = "blas")]
use {openblas_src as _, rs_blas as blas, std::os::raw::c_int};
/// Level-1 BLAS-style operations (inner product, `axpy`, 2-norm) on dense
/// vectors with entries of type `T`.
pub trait BLAS1<T> {
    /// Compute the inner product of `self` and `other`.
    ///
    /// `other` is expected to have the same length as `self`.
    fn dot(&self, other: &[T]) -> T;

    /// Compute the inner product over the common prefix of `self` and `other`.
    ///
    /// The inner product is computed on the smaller of the two lengths.
    /// All remaining entries of the longer vector are treated as if the
    /// shorter vector were padded with zeros.
    fn dot_begin(&self, other: &[T]) -> T;

    /// Compute `self = self + alpha * y` in place.
    ///
    /// `y` is expected to have the same length as `self`.
    fn add_scaled(&mut self, alpha: T, y: &[T]);

    /// Return the Euclidean (2-)norm of this vector.
    fn norm2(&self) -> T;
}
impl BLAS1<f64> for [f64] {
    /// Inner product of `self` and `other`.
    ///
    /// The equal-length requirement is only checked in debug builds; in
    /// release builds a mismatch behaves like [`BLAS1::dot_begin`].
    fn dot(&self, other: &[f64]) -> f64 {
        debug_assert_eq!(self.len(), other.len(), "Vectors must have the same size");
        Self::dot_begin(self, other)
    }

    /// Inner product over the first `min(self.len(), other.len())` entries.
    fn dot_begin(&self, other: &[f64]) -> f64 {
        // SAFETY: `n` is the minimum of both lengths and both strides are 1,
        // so BLAS only reads in-bounds elements of either slice.
        // NOTE(review): `as c_int` would truncate for lengths > c_int::MAX.
        #[cfg(feature = "blas")]
        unsafe {
            blas::ddot(self.len().min(other.len()) as c_int, &self, 1, &other, 1)
        }
        #[cfg(not(feature = "blas"))]
        {
            // `zip` stops at the shorter slice, matching the BLAS call above.
            self.iter().zip(other.iter()).map(|(x, y)| x * y).sum::<f64>()
        }
    }

    /// Compute `self = self + alpha * y` in place.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `y` have different lengths.
    fn add_scaled(&mut self, alpha: f64, y: &[f64]) {
        assert_eq!(self.len(), y.len());
        // SAFETY: both slices have exactly `self.len()` elements (asserted
        // above) and both strides are 1, so all accesses are in bounds.
        #[cfg(feature = "blas")]
        unsafe {
            blas::daxpy(self.len() as c_int, alpha, &y, 1, &mut self[..], 1)
        }
        #[cfg(not(feature = "blas"))]
        {
            for (x, y) in self.iter_mut().zip(y.iter()) {
                *x += alpha * y;
            }
        }
    }

    /// Euclidean (2-)norm of this vector.
    ///
    /// Entries whose squares would overflow or underflow are handled without
    /// spurious `inf`/`0` results. A NaN entry yields NaN. An infinite entry
    /// (with no NaN present) yields `+inf`.
    fn norm2(&self) -> f64 {
        // SAFETY: `n == self.len()` with stride 1, so BLAS reads only
        // in-bounds elements.
        #[cfg(feature = "blas")]
        unsafe {
            blas::dnrm2(self.len() as c_int, &self, 1)
        }
        #[cfg(not(feature = "blas"))]
        {
            // Fast path: the plain sum of squares is accurate whenever it
            // neither overflowed (finite) nor dropped into the subnormal range
            // (>= MIN_POSITIVE). Above that threshold, any individually
            // underflowed square contributes at most about eps/2 relative
            // error, the same as ordinary rounding. A NaN sum can only come
            // from a NaN entry (all terms are non-negative), so let it
            // propagate.
            let ssq = self.iter().map(|x| x * x).sum::<f64>();
            if ssq.is_nan() || (ssq.is_finite() && ssq >= f64::MIN_POSITIVE) {
                ssq.sqrt()
            } else {
                // Slow path, taken on overflow, underflow, infinite entries or
                // an all-zero vector. `hypot` never forms the intermediate
                // square, so very large or very small magnitudes survive.
                self.iter().fold(0.0_f64, |acc, &x| acc.hypot(x))
            }
        }
    }
}