RsBundle  Artifact [f9f76240bb]

Artifact f9f76240bb660bc8e15b61e09816858e80392af3:

  • File src/data/raw.rs — part of check-in [878867274d] at 2020-07-20 16:01:28 on branch minorant-trait — Move blas vector operations to `data::raw` (user: fifr size: 2521) [more...]

/*
 * Copyright (c) 2020 Frank Fischer <frank-fischer@shadow-soft.de>
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
 */

//! BLAS-like low-level vector routines.

#[cfg(feature = "blas")]
use {openblas_src as _, rs_blas as blas, std::os::raw::c_int};

/// Level-1 (vector-vector) BLAS-style operations on a vector whose
/// elements have type `T`.
///
/// The other operand of the binary operations is always passed as a
/// plain slice `&[T]`.
pub trait BLAS1<T> {
    /// Compute the inner product of `self` and `y`.
    ///
    /// Both vectors are expected to have the same number of elements;
    /// use [`dot_begin`](BLAS1::dot_begin) if the lengths may differ.
    fn dot(&self, y: &[T]) -> T;

    /// Compute the inner product over the common prefix of `self` and `y`.
    ///
    /// The inner product is computed on the smaller of the two
    /// dimensions. All other elements are assumed to be zero.
    fn dot_begin(&self, y: &[T]) -> T;

    /// Compute `self = self + alpha * y` in place.
    fn add_scaled(&mut self, alpha: T, y: &[T]);

    /// Return the 2-norm (Euclidean norm) of this vector.
    fn norm2(&self) -> T;
}

/// `BLAS1` operations for slices of `f64`.
///
/// With the `blas` feature enabled the work is delegated to the
/// corresponding BLAS routines (`ddot`, `daxpy`, `dnrm2`); otherwise an
/// equivalent pure-Rust implementation is used.
impl BLAS1<f64> for [f64] {
    /// Compute the inner product of two vectors of equal length.
    ///
    /// The length equality is only checked in debug builds; in release
    /// builds the result is the one of [`dot_begin`](BLAS1::dot_begin).
    fn dot(&self, other: &[f64]) -> f64 {
        debug_assert_eq!(self.len(), other.len(), "Vectors must have the same size");
        Self::dot_begin(self, other)
    }

    /// Compute the inner product over the first `min(self.len(), other.len())`
    /// elements of both vectors.
    ///
    /// # Panics
    ///
    /// With the `blas` feature, panics if the common length does not fit
    /// into a C `int`.
    fn dot_begin(&self, other: &[f64]) -> f64 {
        #[cfg(feature = "blas")]
        {
            // A plain `as` cast would silently truncate or wrap for huge
            // vectors and make BLAS process the wrong number of elements.
            let n = <c_int as std::convert::TryFrom<usize>>::try_from(self.len().min(other.len()))
                .expect("vector length exceeds the BLAS index range");
            // SAFETY: `n` does not exceed the length of either slice and
            // both strides are 1, so the routine stays within both slices.
            unsafe { blas::ddot(n, self, 1, other, 1) }
        }
        #[cfg(not(feature = "blas"))]
        {
            // `zip` stops at the shorter operand, i.e. the missing
            // elements are treated as zero.
            self.iter().zip(other.iter()).map(|(x, y)| x * y).sum::<f64>()
        }
    }

    /// Compute `self = self + alpha * y` in place.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `y` differ in length and, with the `blas`
    /// feature, if the length does not fit into a C `int`.
    fn add_scaled(&mut self, alpha: f64, y: &[f64]) {
        // This check is also what makes the BLAS call below sound, hence
        // it must stay a hard assertion.
        assert_eq!(self.len(), y.len(), "Vectors must have the same size");
        #[cfg(feature = "blas")]
        {
            let n = <c_int as std::convert::TryFrom<usize>>::try_from(self.len())
                .expect("vector length exceeds the BLAS index range");
            // SAFETY: both slices hold exactly `n` elements (asserted
            // above) and both strides are 1.
            unsafe { blas::daxpy(n, alpha, y, 1, &mut self[..], 1) }
        }
        #[cfg(not(feature = "blas"))]
        {
            for (x, y) in self.iter_mut().zip(y.iter()) {
                *x += alpha * y;
            }
        }
    }

    /// Return the Euclidean norm of this vector.
    ///
    /// The empty vector has norm `0`. A `NaN` entry yields `NaN`, an
    /// infinite entry (without `NaN`s) yields `+inf`.
    ///
    /// # Panics
    ///
    /// With the `blas` feature, panics if the length does not fit into a
    /// C `int`.
    fn norm2(&self) -> f64 {
        #[cfg(feature = "blas")]
        {
            let n = <c_int as std::convert::TryFrom<usize>>::try_from(self.len())
                .expect("vector length exceeds the BLAS index range");
            // SAFETY: the slice holds exactly `n` elements and the stride is 1.
            unsafe { blas::dnrm2(n, self, 1) }
        }
        #[cfg(not(feature = "blas"))]
        {
            let sum_sq = self.iter().map(|x| x * x).sum::<f64>();
            if sum_sq.is_nan() || (sum_sq.is_finite() && sum_sq >= f64::MIN_POSITIVE) {
                // Well-conditioned case (or NaN, which simply propagates):
                // the plain formula is accurate.
                sum_sq.sqrt()
            } else {
                // The sum of squares overflowed to infinity or underflowed
                // to (almost) zero. Rescale by the largest magnitude so the
                // squares stay in range.
                let scale = self.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
                if scale == 0.0 || scale.is_infinite() {
                    // All-zero/empty vector, or a genuinely infinite entry.
                    scale
                } else {
                    let scaled_sq = self
                        .iter()
                        .map(|x| {
                            let r = x / scale;
                            r * r
                        })
                        .sum::<f64>();
                    scale * scaled_sq.sqrt()
                }
            }
        }
    }
}