RsBundle  Artifact [13909cace2]

Artifact 13909cace246c1b1ec3fab08271db2465518feda:

  • File src/mpi/problem.rs — part of check-in [d576101ecd] at 2023-07-04 19:44:37 on branch mpi-cvx — Fix some clippy warnings (user: fifr size: 11608) [more...]

/*
 * Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
 */

use super::msg::{ResultMsg, WorkerMsg};
use crate::problem::{
    FirstOrderProblem, ResultSender, SubgradientExtender, UpdateSendError, UpdateSender, UpdateState,
};
use crate::{DVector, Minorant, Real};

use log::{debug, error, info};
use mpi::environment::Universe;
use mpi::topology::{Communicator, SystemCommunicator};
use mpi::Rank;
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::sync::mpsc::{channel, Sender};
use std::sync::{RwLock, RwLockReadGuard};
use thiserror::Error;
use threadpool::ThreadPool;

use crate::mpi::msg::{recv_msg, send_msg};
use serde::de::DeserializeOwned;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Instant;

/// Error raised by the MPI [`Problem`].
#[derive(Debug, Error, Serialize, Deserialize)]
// `MPI` is an acronym used verbatim as a variant name.
#[allow(clippy::upper_case_acronyms)]
pub enum Error<E> {
    /// MPI error.
    ///
    /// NOTE(review): not constructed anywhere in this file — presumably
    /// produced on the worker side; confirm.
    #[error("MPI error")]
    MPI,
    /// Original oracle error.
    ///
    /// Wraps an error `E` raised by the underlying (non-MPI) problem.
    #[error("Error by underlying oracle")]
    OracleError(#[source] E),
}

/// A first order problem whose model updates can be distributed to
/// other nodes of an MPI computation.
///
/// All methods have no-op default implementations, so problems with a
/// static model need not override anything.
pub trait DistributedFirstOrderProblem: FirstOrderProblem + Send + Sync {
    /// Abstract information about a model update.
    ///
    /// The update must be a serializable representation of a model
    /// change, which can be transferred to different nodes in a
    /// distributed computing platform.
    type Update: Serialize + DeserializeOwned + Clone + Send + Sync;

    /// Compute an update of the problem.
    ///
    /// The update should be computed (e.g. which variables should be
    /// added) but not yet applied to the problem. The update will be
    /// applied later by calling [`apply_update`].
    ///
    /// The function may return immediately and compute the update in
    /// background.
    ///
    /// The default implementation computes nothing and returns
    /// `Ok(false)`. (NOTE(review): the caller in this file discards
    /// the boolean — confirm its intended meaning, presumably "an
    /// update was/will be computed".)
    ///
    /// # Parameters
    /// - `state` is the current state of the model on which the update is based
    /// - `apply` callback to which the update should be passed
    #[allow(unused_variables)]
    fn compute_update<U>(
        &self,
        state: U,
        apply: impl FnOnce(Self::Update) -> Result<(), Self::Err> + Send + 'static,
    ) -> Result<bool, Self::Err>
    where
        U: UpdateState<<Self::Minorant as Minorant>::Primal>,
    {
        Ok(false)
    }

    /// Apply an update to the problem.
    ///
    /// The update has been computed by a preceding call to [`compute_update`].
    ///
    /// Note that the update may have been computed on a different
    /// node in a distributed computing platform.
    #[allow(unused_variables)]
    fn apply_update(&mut self, update: &Self::Update) -> Result<(), Self::Err> {
        Ok(())
    }

    /// Send an update to the main algorithm.
    ///
    /// This function is usually called after an update has been
    /// applied (to all compute nodes). Its task is to transmit the
    /// update to the main algorithm using the [`UpdateSender`]
    /// parameter `tx`.
    #[allow(unused_variables)]
    fn send_update<S>(&self, update: &Self::Update, tx: S) -> Result<(), Self::Err>
    where
        S: UpdateSender<Self> + 'static,
        Self: Sized,
    {
        Ok(())
    }
}

/// Message passed from the main node to one of the per-worker sender
/// threads: the `WorkerMsg` to forward over MPI, plus an optional
/// sender for the worker's results.
///
/// The result sender is `None` for fire-and-forget messages
/// (`WorkerMsg::ApplyUpdate`) and `Some` for evaluation requests.
type ClientMessage<P> = (
    WorkerMsg<<P as DistributedFirstOrderProblem>::Update>,
    Option<Box<dyn ResultSender<Problem<P>> + 'static>>,
);

/// Book-keeping for the MPI communication, created lazily by
/// `Problem::start` and dropped by `stop`/`drop` (which closes the
/// channels and thereby shuts down the worker threads).
struct MPIData<P: DistributedFirstOrderProblem> {
    // Number of worker processes (world size minus the main node, rank 0).
    nclients: Rank,
    // Worker ranks currently known to be idle.
    free_clients: VecDeque<Rank>,
    // Next rank for round-robin assignment when no idle worker is known.
    next_client: Rank,
    // One channel per worker; index `rank - 1` maps to worker `rank`.
    client_txs: Vec<Sender<ClientMessage<P>>>,
}

/// The first order problem for the mpi end-point on the main node.
///
/// This is the problem called from the bundle algorithm. It does not
/// solve the subproblems directly but transfers the requests to MPI
/// worker nodes. Use [`Problem`] to wrap a regular (non-MPI) first
/// order problem.
pub struct Problem<P: DistributedFirstOrderProblem> {
    // The MPI universe; kept alive for the lifetime of the problem.
    universe: Universe,
    // The wrapped problem, shared with the update callback threads.
    problem: Arc<RwLock<P>>,

    // Communication state; `Some` between `start` and `stop`.
    mpidata: Option<MPIData<P>>,
}

impl<P: DistributedFirstOrderProblem> Problem<P> {
    /// Return a read guard for the underlying problem.
    ///
    /// Panics if the lock has been poisoned by a panicking writer.
    pub fn problem(&self) -> RwLockReadGuard<P> {
        let lock = &self.problem;
        lock.read().unwrap()
    }
}

impl<P: DistributedFirstOrderProblem> Drop for Problem<P> {
    /// Tear down the MPI communication state (closing the worker
    /// channels) before the remaining fields are dropped.
    fn drop(&mut self) {
        self.mpidata = None;
    }
}

impl<P: DistributedFirstOrderProblem> Problem<P> {
    /// Create a new main-node MPI problem wrapping `problem`.
    ///
    /// The MPI communication threads are only started lazily by
    /// `FirstOrderProblem::start`.
    pub fn new(universe: Universe, problem: P) -> Self {
        let problem = Arc::new(RwLock::new(problem));
        Self {
            universe,
            problem,
            mpidata: None,
        }
    }
}

impl<P: DistributedFirstOrderProblem + 'static> FirstOrderProblem for Problem<P>
where
    P::Minorant: for<'a> Deserialize<'a>,
    P::Err: for<'a> Deserialize<'a>,
{
    type Err = Error<P::Err>;

    type Minorant = P::Minorant;

    // The following query methods simply delegate to the wrapped problem.

    fn num_variables(&self) -> usize {
        self.problem.read().unwrap().num_variables()
    }

    fn lower_bounds(&self) -> Option<Vec<Real>> {
        self.problem.read().unwrap().lower_bounds()
    }

    fn upper_bounds(&self) -> Option<Vec<Real>> {
        self.problem.read().unwrap().upper_bounds()
    }

    fn constraint_index(&self, i: usize) -> Option<usize> {
        self.problem.read().unwrap().constraint_index(i)
    }

    fn num_subproblems(&self) -> usize {
        self.problem.read().unwrap().num_subproblems()
    }

    // Start the wrapped problem and, on the first call, spawn one
    // communication thread per MPI worker (ranks 1..world.size();
    // rank 0 is this main node).
    //
    // Each thread owns the receiving end of a channel stored in
    // `MPIData::client_txs` and forwards incoming `WorkerMsg`s to its
    // worker process via MPI. When the channel is closed (i.e. when
    // `MPIData` is dropped in `stop` or on drop), the thread sends a
    // final `Terminate` message to its worker and exits.
    fn start(&mut self) {
        self.problem.write().unwrap().start();
        if self.mpidata.is_none() {
            let world = self.universe.world();
            // Initially, all worker ranks are considered idle.
            let free_clients = (1..world.size()).collect();
            // NOTE(review): the pool is sized `world.size()` although only
            // `world.size() - 1` threads are spawned below — presumably
            // harmless slack, but confirm this is intentional.
            let pool = ThreadPool::new(world.size().to_usize().unwrap());

            let client_txs = (1..world.size())
                .map(|rank| {
                    let (tx, rx) = channel::<ClientMessage<P>>();

                    pool.execute(move || {
                        let world = SystemCommunicator::world();
                        let client = world.process_at_rank(rank);

                        // Forward messages until all senders are dropped.
                        while let Ok((msg, result_tx)) = rx.recv() {
                            let start_time = Instant::now();
                            // send evaluation point
                            send_msg(&client, &msg);

                            if let WorkerMsg::ApplyUpdate(_) = msg {
                                // no response expected
                                // TODO: this might be bad because `ApplyUpdate` could fail, hence
                                // an error-message is sent
                                continue;
                            }

                            // Every other message carries a result channel.
                            let result_tx = result_tx.unwrap();

                            // wait for response
                            loop {
                                let msg = recv_msg(&client);

                                match msg {
                                    ResultMsg::ObjectiveValue { value, .. } => result_tx.objective(value).unwrap(),
                                    ResultMsg::Minorant { minorant, .. } => result_tx.minorant(minorant).unwrap(),
                                    ResultMsg::Done { index, .. } => {
                                        // Round-trip time in seconds.
                                        debug!(
                                            "Worker index:{} time:{}",
                                            index,
                                            start_time.elapsed().as_millis() as f64 / 1000.0
                                        );
                                        break;
                                    }
                                    ResultMsg::Error { error, .. } => {
                                        result_tx.error(Error::OracleError(error)).unwrap()
                                    }
                                }
                            }
                        }

                        info!("Terminate worker thread {}", rank);

                        // Channel closed: tell the worker process to shut down.
                        send_msg(&client, &WorkerMsg::<P::Update>::Terminate);
                    });

                    tx
                })
                .collect();

            self.mpidata = Some(MPIData {
                nclients: world.size() - 1,
                free_clients,
                next_client: 1,
                client_txs,
            });
        }
    }

    // Drop the MPI communication state first (closing the channels,
    // which terminates the worker threads), then stop the wrapped
    // problem.
    fn stop(&mut self) {
        self.mpidata.take();
        self.problem.write().unwrap().stop();
    }

    // Dispatch the evaluation of subproblem `i` at point `y` to a
    // worker process; results are delivered asynchronously via `tx`.
    fn evaluate<S>(&mut self, i: usize, y: Arc<DVector>, tx: S) -> Result<(), Self::Err>
    where
        S: ResultSender<Self> + 'static,
        Self: Sized,
    {
        if self.mpidata.is_none() {
            self.start()
        }

        let mpidata = self.mpidata.as_mut().unwrap();

        // get client: prefer an idle worker, otherwise assign
        // round-robin over ranks 1..=nclients.
        // NOTE(review): ranks are never pushed back onto `free_clients`
        // in this file, so after the initial burst every request is
        // assigned round-robin — confirm this is the intended policy.
        let rank = mpidata.free_clients.pop_front().unwrap_or_else(|| {
            let r = mpidata.next_client;
            let n = mpidata.nclients; // rank 0 is this node
            mpidata.next_client = (mpidata.next_client % n) + 1;
            r
        }) as usize;

        // `client_txs[k]` belongs to worker rank `k + 1`.
        mpidata.client_txs[rank - 1]
            .send((WorkerMsg::Evaluate { i, y }, Some(Box::new(tx))))
            .unwrap();

        Ok(())
    }

    // Compute a model update, broadcast it to all worker processes,
    // apply it to the local problem and forward it to the algorithm.
    fn update<U, S>(&mut self, state: U, tx: S) -> Result<(), Self::Err>
    where
        U: UpdateState<<Self::Minorant as Minorant>::Primal>,
        S: UpdateSender<Self> + 'static,
        Self: Sized,
    {
        if self.mpidata.is_none() {
            self.start()
        }

        let mpidata = self.mpidata.as_mut().unwrap();

        let problem = self.problem.clone();
        let client_txs = mpidata.client_txs.clone();
        self.problem
            .write()
            .unwrap()
            .compute_update(state, move |update| {
                let update = Arc::new(update);
                let world = SystemCommunicator::world();
                // Broadcast the update to every worker process
                // (fire-and-forget; no result channel is attached).
                for rank in 1..world.size() as usize {
                    client_txs[rank - 1]
                        .send((WorkerMsg::ApplyUpdate(update.clone()), None))
                        .unwrap()
                }
                // Apply the update locally, then hand it to the main
                // algorithm, wrapping oracle errors on the way.
                let mut problem = problem.write().unwrap();
                problem.apply_update(&update)?;
                problem.send_update(&update, WorkerUpdateSender(tx))?;

                Ok(())
            })
            .map_err(Error::OracleError)?;

        Ok(())
    }
}

/// Adapter forwarding update messages from the wrapped problem to the
/// bundle algorithm, converting oracle errors into [`Error`].
struct WorkerUpdateSender<S>(S);

impl<P, S> UpdateSender<P> for WorkerUpdateSender<S>
where
    P: DistributedFirstOrderProblem + 'static,
    P::Minorant: for<'a> Deserialize<'a>,
    P::Err: for<'a> Deserialize<'a>,
    S: UpdateSender<Problem<P>>,
{
    fn add_variables(
        &self,
        bounds: Vec<(Real, Real)>,
        sgext: Box<dyn SubgradientExtender<P::Minorant, P::Err>>,
    ) -> Result<(), UpdateSendError> {
        self.0.add_variables(
            bounds,
            Box::new(move |fidx: usize, m: &mut P::Minorant| {
                sgext.extend_subgradient(fidx, m).map_err(Error::OracleError)
            }),
        )
    }

    fn error(&self, err: P::Err) -> Result<(), UpdateSendError> {
        self.0.error(Error::OracleError(err))
    }
}