RsBundle  Artifact [cd731597ed]

Artifact cd731597ed9df028c3fa1e4f8c307c36e7c5797e:

  • File src/mpi/problem.rs — part of check-in [82ae17b6df] at 2023-04-05 20:40:18 on branch mpi — mpi::problem: remove thread-pool instance variable. It has been unused. (user: fifr size: 10417)

/*
 * Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
 */

use super::msg::{ResultMsg, ResultType, WorkerMsg};
use crate::problem::{
    FirstOrderProblem, ResultSender, SubgradientExtender, UpdateSendError, UpdateSender, UpdateState,
};
use crate::{DVector, Minorant, Real};

use flexbuffers;
use log::{error, info};
use mpi::environment::Universe;
use mpi::point_to_point::{Destination, Source};
use mpi::topology::{Communicator, SystemCommunicator};
use mpi::Rank;
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::sync::mpsc::{channel, Sender};
use std::sync::RwLock;
use thiserror::Error;
use threadpool::ThreadPool;

use serde::de::DeserializeOwned;
use std::collections::VecDeque;
use std::sync::Arc;

/// Error raised by the MPI [`Problem`].
///
/// The type parameter `E` is the error type of the wrapped problem:
/// `Problem<P>` uses `Error<P::Err>` as its own error type.
#[derive(Debug, Error)]
pub enum Error<E> {
    /// MPI error.
    #[error("MPI error")]
    MPI,
    /// Error from the remote oracle.
    ///
    /// Reported to the result sender when a worker answers an evaluation
    /// request with a `ResultType::Error` message.
    #[error("Error raised on remote host")]
    Remote,
    /// Original oracle error.
    ///
    /// Wraps an error value returned by the local, wrapped problem. Note
    /// that the display message does not include the wrapped error.
    #[error("Error by underlying oracle")]
    OracleError(E),
}

/// A first order problem whose modifications ("updates") can be replicated
/// on remote worker processes.
///
/// [`Problem`] wraps an implementation of this trait. When an update is
/// created, it sends it to every worker as a `WorkerMsg::Update` message and
/// applies it to the local instance as well.
///
/// All methods have default implementations that do nothing.
pub trait DistributedFirstOrderProblem: FirstOrderProblem + Send + Sync {
    /// Description of a problem modification.
    ///
    /// It must be serializable because it is shipped to the workers.
    type Update: Serialize + DeserializeOwned + Clone + Send + Sync;

    /// Create an update of the problem based on the current `state`.
    ///
    /// `apply` is a callback that takes the created update. The callback
    /// passed by [`Problem`] forwards the update to all workers, calls
    /// [`apply_update`](Self::apply_update) on the local instance and
    /// finally [`send_update`](Self::send_update).
    ///
    /// [`Problem`] holds the write lock of the problem while this method
    /// runs and its callback acquires the same lock, so `apply` must not be
    /// called before this method has returned (e.g. call it from another
    /// thread).
    ///
    /// The default implementation never calls `apply` and returns
    /// `Ok(false)`.
    ///
    /// NOTE(review): the meaning of the returned flag is not visible here —
    /// presumably it tells whether an update has been started. [`Problem`]
    /// currently ignores it.
    #[allow(unused_variables)]
    fn create_update<U>(
        &self,
        state: U,
        apply: impl FnOnce(Self::Update) -> Result<(), Self::Err> + Send + 'static,
    ) -> Result<bool, Self::Err>
    where
        U: UpdateState<<Self::Minorant as Minorant>::Primal>,
    {
        Ok(false)
    }

    /// Apply a previously created `update` to this problem instance.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    #[allow(unused_variables)]
    fn apply_update(&mut self, update: &Self::Update) -> Result<(), Self::Err> {
        Ok(())
    }

    /// Report the consequences of `update` through the update sender `tx`.
    ///
    /// [`Problem`] calls this right after
    /// [`apply_update`](Self::apply_update) has succeeded.
    ///
    /// The default implementation sends nothing and returns `Ok(())`.
    #[allow(unused_variables)]
    fn send_update<S>(&self, update: &Self::Update, tx: S) -> Result<(), Self::Err>
    where
        S: UpdateSender<Self> + 'static,
        Self: Sized,
    {
        Ok(())
    }
}

/// Message passed from [`Problem`] to one of its worker threads.
///
/// The first component is the message forwarded to the MPI worker. The
/// second component is the sender for the results of that request: it is
/// `Some` for evaluation requests and `None` for updates, which are not
/// answered by the worker.
type ClientMessage<P> = (
    WorkerMsg<<P as DistributedFirstOrderProblem>::Update>,
    Option<Box<dyn ResultSender<Problem<P>> + 'static>>,
);

/// State of the connection to the MPI worker processes.
///
/// It exists only while the problem is started.
struct MPIData<P: DistributedFirstOrderProblem> {
    /// Number of worker processes, i.e. the world size minus one (rank 0 is
    /// this node).
    nclients: Rank,
    /// Ranks of the workers that have not been assigned a request, yet.
    /// Initially this holds all worker ranks `1..world.size()`.
    free_clients: VecDeque<Rank>,
    /// Rank that gets the next request once `free_clients` is empty. It
    /// cycles through `1..=nclients`.
    next_client: Rank,
    /// Sending ends of the channels to the worker threads. The entry at
    /// index `r - 1` belongs to the worker with rank `r`.
    client_txs: Vec<Sender<ClientMessage<P>>>,
}

/// The first order problem for the mpi end-point on the main node.
///
/// This is the problem called from the bundle algorithm. It does not
/// solve the subproblems directly but transfers the requests to MPI
/// worker nodes. Use [`Problem`] to wrap a regular (non-MPI) first
/// order problem.
pub struct Problem<P: DistributedFirstOrderProblem> {
    /// The MPI universe. Its world communicator determines the number of
    /// worker processes.
    universe: Universe,
    /// The wrapped problem instance on this node. It is shared with the
    /// callbacks created for updates, hence the `Arc<RwLock<..>>`.
    problem: Arc<RwLock<P>>,

    /// Connection state; `None` as long as the problem has not been started
    /// and after it has been stopped.
    mpidata: Option<MPIData<P>>,
}

impl<P: DistributedFirstOrderProblem> Drop for Problem<P> {
    /// Release the worker channels before the remaining fields are dropped.
    ///
    /// Fields are dropped in declaration order after this method returns, so
    /// without this `universe` would be dropped before `mpidata`. Dropping
    /// the MPI data closes the channels to the worker threads, which makes
    /// them send the terminate message to their workers.
    ///
    /// NOTE(review): presumably this matters because dropping the universe
    /// finalizes MPI — confirm. The worker threads are not joined here.
    fn drop(&mut self) {
        drop(self.mpidata.take());
    }
}

impl<P: DistributedFirstOrderProblem> Problem<P> {
    /// Wrap `problem` so that it can be solved using the workers of the MPI
    /// `universe`.
    ///
    /// No communication takes place here; the connection state is created
    /// when the problem is started.
    pub fn new(universe: Universe, problem: P) -> Self {
        let problem = Arc::new(RwLock::new(problem));
        Self {
            universe,
            problem,
            mpidata: None,
        }
    }
}

impl<P: DistributedFirstOrderProblem + 'static> FirstOrderProblem for Problem<P>
where
    P::Minorant: for<'a> Deserialize<'a>,
{
    type Err = Error<P::Err>;

    type Minorant = P::Minorant;

    fn num_variables(&self) -> usize {
        self.problem.read().unwrap().num_variables()
    }

    fn lower_bounds(&self) -> Option<Vec<Real>> {
        self.problem.read().unwrap().lower_bounds()
    }

    fn upper_bounds(&self) -> Option<Vec<Real>> {
        self.problem.read().unwrap().upper_bounds()
    }

    fn num_subproblems(&self) -> usize {
        self.problem.read().unwrap().num_subproblems()
    }

    fn start(&mut self) {
        self.problem.write().unwrap().start();
        if self.mpidata.is_none() {
            let world = self.universe.world();
            let free_clients = (1..world.size()).collect();
            let pool = ThreadPool::new(world.size().to_usize().unwrap());

            let client_txs = (1..world.size())
                .map(|rank| {
                    let (tx, rx) = channel::<ClientMessage<P>>();

                    pool.execute(move || {
                        let world = SystemCommunicator::world();
                        let client = world.process_at_rank(rank);

                        while let Ok((msg, result_tx)) = rx.recv() {
                            // send evaluation point
                            let mut s = flexbuffers::FlexbufferSerializer::new();
                            msg.serialize(&mut s).unwrap();
                            client.send(s.view());

                            if let WorkerMsg::Update(_) = msg {
                                // no response expected
                                continue;
                            }

                            let result_tx = result_tx.unwrap();

                            // wait for response
                            loop {
                                let (msg, _) = client.receive::<ResultMsg>();
                                match msg.typ {
                                    ResultType::ObjectiveValue => {
                                        let (obj, _) = client.receive::<Real>();
                                        result_tx.objective(obj).unwrap();
                                    }
                                    ResultType::Minorant => {
                                        let mut raw = vec![0u8; msg.n];
                                        client.receive_into(&mut raw[..]);

                                        let r = flexbuffers::Reader::get_root(&raw[..]).unwrap();
                                        let minorant = Self::Minorant::deserialize(r).unwrap();
                                        result_tx.minorant(minorant).unwrap();
                                    }
                                    ResultType::Done => {
                                        break;
                                    }
                                    ResultType::Error => {
                                        result_tx.error(Error::Remote).unwrap();
                                    }
                                }
                            }
                        }

                        info!("Terminate worker thread {}", rank);
                        let mut s = flexbuffers::FlexbufferSerializer::new();
                        WorkerMsg::<P::Update>::Terminate.serialize(&mut s).unwrap();
                        client.send(s.view());
                    });

                    tx
                })
                .collect();

            self.mpidata = Some(MPIData {
                nclients: world.size() - 1,
                free_clients,
                next_client: 1,
                client_txs,
            });
        }
    }

    fn stop(&mut self) {
        self.mpidata.take();
        self.problem.write().unwrap().stop();
    }

    fn evaluate<S>(&mut self, i: usize, y: Arc<DVector>, tx: S) -> Result<(), Self::Err>
    where
        S: ResultSender<Self> + 'static,
        Self: Sized,
    {
        if self.mpidata.is_none() {
            self.start()
        }

        let mpidata = self.mpidata.as_mut().unwrap();

        // get client
        let rank = mpidata.free_clients.pop_front().unwrap_or_else(|| {
            let r = mpidata.next_client;
            let n = mpidata.nclients; // rank 0 is this node
            mpidata.next_client = (mpidata.next_client % n) + 1;
            r
        }) as usize;

        mpidata.client_txs[rank - 1]
            .send((WorkerMsg::Evaluate { i, y }, Some(Box::new(tx))))
            .unwrap();

        Ok(())
    }

    fn update<U, S>(&mut self, state: U, tx: S) -> Result<(), Self::Err>
    where
        U: UpdateState<<Self::Minorant as Minorant>::Primal>,
        S: UpdateSender<Self> + 'static,
        Self: Sized,
    {
        if self.mpidata.is_none() {
            self.start()
        }

        let mpidata = self.mpidata.as_mut().unwrap();

        let problem = self.problem.clone();
        let client_txs = mpidata.client_txs.clone();
        if !self
            .problem
            .write()
            .unwrap()
            .create_update(state, move |update| {
                let update = Arc::new(update);
                let world = SystemCommunicator::world();
                for rank in 1..world.size() as usize {
                    client_txs[rank - 1]
                        .send((WorkerMsg::Update(update.clone()), None))
                        .unwrap()
                }
                let mut problem = problem.write().unwrap();
                problem.apply_update(&update).map_err(|_| "Error".to_string()).unwrap();
                problem
                    .send_update(&update, WorkerUpdateSender(tx))
                    .map_err(|_| "Error2".to_string())
                    .unwrap();

                Ok(())
            })
            .map_err(Error::OracleError)?
        {
            return Ok(());
        }

        Ok(())
    }
}

/// Adapter around an update sender for [`Problem<P>`] that can be handed to
/// the wrapped problem `P`.
///
/// Errors of `P` are converted to [`Error::OracleError`] before they are
/// forwarded to the wrapped sender.
struct WorkerUpdateSender<S>(S);

impl<P, S> UpdateSender<P> for WorkerUpdateSender<S>
where
    P: DistributedFirstOrderProblem + 'static,
    P::Minorant: for<'a> Deserialize<'a>,
    S: UpdateSender<Problem<P>>,
{
    fn add_variables(
        &self,
        bounds: Vec<(Real, Real)>,
        sgext: Box<dyn SubgradientExtender<P::Minorant, P::Err>>,
    ) -> Result<(), UpdateSendError> {
        self.0.add_variables(
            bounds,
            Box::new(move |fidx: usize, m: &mut P::Minorant| {
                sgext.extend_subgradient(fidx, m).map_err(Error::OracleError)
            }),
        )
    }

    fn error(&self, err: P::Err) -> Result<(), UpdateSendError> {
        self.0.error(Error::OracleError(err))
    }
}