/*
* Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
*
* This program is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>
*/
use super::msg::{ResultMsg, WorkerMsg};
use crate::problem::{
FirstOrderProblem, ResultSender, SubgradientExtender, UpdateSendError, UpdateSender, UpdateState,
};
use crate::{DVector, Minorant, Real};
use log::{error, info};
use mpi::environment::Universe;
use mpi::topology::{Communicator, SystemCommunicator};
use mpi::Rank;
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::sync::mpsc::{channel, Sender};
use std::sync::RwLock;
use thiserror::Error;
use threadpool::ThreadPool;
use crate::mpi::msg::{recv_msg, send_msg};
use serde::de::DeserializeOwned;
use std::collections::VecDeque;
use std::sync::Arc;
/// Error raised by the MPI [`Problem`].
///
/// Wraps either a failure of the MPI transport itself or an error
/// produced by the underlying oracle on a worker node.
#[derive(Debug, Error)]
pub enum Error<E> {
    /// MPI error.
    ///
    /// NOTE(review): carries no detail about which MPI operation
    /// failed — consider attaching context if this variant is ever
    /// constructed in practice.
    #[error("MPI error")]
    MPI,
    /// Original oracle error.
    #[error("Error by underlying oracle")]
    OracleError(#[source] E),
}
/// A first order problem whose model updates can be shipped between
/// nodes of a distributed (MPI) computation.
///
/// In addition to the plain [`FirstOrderProblem`] interface, an
/// implementation describes a model change as a serializable value
/// ([`Self::Update`]) so that an update computed on one node can be
/// transferred to and applied on every other node.
pub trait DistributedFirstOrderProblem: FirstOrderProblem + Send + Sync {
    /// Abstract information about a model update.
    ///
    /// The update must be a serializable representation of a model
    /// change, which can be transferred to different nodes in a
    /// distributed computing platform.
    type Update: Serialize + DeserializeOwned + Clone + Send + Sync;

    /// Compute an update of the problem.
    ///
    /// The update should be computed (e.g. which variables should be
    /// added) but not yet applied to the problem. The update will be
    /// applied later by calling [`apply_update`].
    ///
    /// The function may return immediately and compute the update in
    /// background.
    ///
    /// The default implementation computes no update and returns
    /// `Ok(false)` without invoking `apply`.
    ///
    /// # Parameters
    /// - `state` is the current state of the model on which the update is based
    /// - `apply` callback to which the update should be passed
    #[allow(unused_variables)]
    fn compute_update<U>(
        &self,
        state: U,
        apply: impl FnOnce(Self::Update) -> Result<(), Self::Err> + Send + 'static,
    ) -> Result<bool, Self::Err>
    where
        U: UpdateState<<Self::Minorant as Minorant>::Primal>,
    {
        Ok(false)
    }

    /// Apply an update to the problem.
    ///
    /// The update has been computed by a preceding call to [`compute_update`].
    ///
    /// Note that the update may have been computed on a different
    /// node in a distributed computing platform.
    ///
    /// The default implementation does nothing.
    #[allow(unused_variables)]
    fn apply_update(&mut self, update: &Self::Update) -> Result<(), Self::Err> {
        Ok(())
    }

    /// Send an update to the main algorithm.
    ///
    /// This function is usually called after an update has been
    /// applied (to all compute nodes). Its task is to transmit the
    /// update to the main algorithm using the [`UpdateSender`]
    /// parameter `tx`.
    ///
    /// The default implementation does nothing.
    #[allow(unused_variables)]
    fn send_update<S>(&self, update: &Self::Update, tx: S) -> Result<(), Self::Err>
    where
        S: UpdateSender<Self> + 'static,
        Self: Sized,
    {
        Ok(())
    }
}
/// Message passed from the main thread to a per-worker sender thread:
/// the MPI message to forward, plus — for requests that expect
/// responses — the channel on which results are reported back to the
/// algorithm (`None` for fire-and-forget messages such as updates).
type ClientMessage<P> = (
    WorkerMsg<<P as DistributedFirstOrderProblem>::Update>,
    Option<Box<dyn ResultSender<Problem<P>> + 'static>>,
);
/// Book-keeping for the pool of MPI worker ranks on the main node.
struct MPIData<P: DistributedFirstOrderProblem> {
    /// Number of worker ranks (world size minus the main rank 0).
    nclients: Rank,
    /// Ranks currently considered idle; preferred when dispatching.
    free_clients: VecDeque<Rank>,
    /// Next rank for round-robin dispatch when `free_clients` is empty.
    next_client: Rank,
    /// Per-worker channels to the sender threads; rank `r` maps to
    /// index `r - 1` (rank 0 has no sender thread).
    client_txs: Vec<Sender<ClientMessage<P>>>,
}
/// The first order problem for the mpi end-point on the main node.
///
/// This is the problem called from the bundle algorithm. It does not
/// solve the subproblems directly but transfers the requests to MPI
/// worker nodes. Use [`Problem`] to wrap a regular (non-MPI) first
/// order problem.
pub struct Problem<P: DistributedFirstOrderProblem> {
    /// Owns the MPI environment; kept alive for the problem's lifetime.
    universe: Universe,
    /// Local (rank 0) copy of the wrapped problem, shared with the
    /// update callback.
    problem: Arc<RwLock<P>>,
    /// Runtime state of the worker pool; `None` until started.
    mpidata: Option<MPIData<P>>,
}
impl<P: DistributedFirstOrderProblem> Drop for Problem<P> {
    fn drop(&mut self) {
        // Dropping `MPIData` closes the per-worker channels; the sender
        // threads then leave their receive loops and tell their MPI
        // workers to terminate. This must happen before `universe` is
        // dropped (fields drop in declaration order after this body).
        self.mpidata.take();
    }
}
impl<P: DistributedFirstOrderProblem> Problem<P> {
    /// Create a new main-node problem wrapping `problem`.
    ///
    /// The MPI worker pool is set up lazily on the first call to
    /// `start` (or implicitly by `evaluate`/`update`), not here.
    pub fn new(universe: Universe, problem: P) -> Self {
        let problem = Arc::new(RwLock::new(problem));
        Self {
            universe,
            problem,
            mpidata: None,
        }
    }
}
impl<P: DistributedFirstOrderProblem + 'static> FirstOrderProblem for Problem<P>
where
    P::Minorant: for<'a> Deserialize<'a>,
    P::Err: for<'a> Deserialize<'a>,
{
    type Err = Error<P::Err>;
    type Minorant = P::Minorant;

    // The simple accessors below just delegate to the local copy of
    // the wrapped problem.
    fn num_variables(&self) -> usize {
        self.problem.read().unwrap().num_variables()
    }

    fn lower_bounds(&self) -> Option<Vec<Real>> {
        self.problem.read().unwrap().lower_bounds()
    }

    fn upper_bounds(&self) -> Option<Vec<Real>> {
        self.problem.read().unwrap().upper_bounds()
    }

    fn num_subproblems(&self) -> usize {
        self.problem.read().unwrap().num_subproblems()
    }

    /// Start the wrapped problem and, on first call, spawn one sender
    /// thread per MPI worker rank.
    ///
    /// Each sender thread owns the receiving end of a channel; it
    /// forwards incoming [`WorkerMsg`]s to its MPI rank and pumps the
    /// responses back into the [`ResultSender`] supplied with the
    /// message.
    fn start(&mut self) {
        self.problem.write().unwrap().start();
        if self.mpidata.is_none() {
            let world = self.universe.world();
            // Ranks 1..size are workers; rank 0 is this main node.
            let free_clients = (1..world.size()).collect();
            // NOTE(review): the pool is sized `world.size()` although
            // only `world.size() - 1` jobs are submitted — one spare
            // thread, harmless but possibly unintended.
            let pool = ThreadPool::new(world.size().to_usize().unwrap());
            let client_txs = (1..world.size())
                .map(|rank| {
                    let (tx, rx) = channel::<ClientMessage<P>>();
                    pool.execute(move || {
                        let world = SystemCommunicator::world();
                        let client = world.process_at_rank(rank);
                        // Runs until the sending half is dropped, i.e.
                        // until `MPIData` is dropped in `stop`/`drop`.
                        while let Ok((msg, result_tx)) = rx.recv() {
                            // send evaluation point
                            send_msg(&client, &msg);
                            if let WorkerMsg::ApplyUpdate(_) = msg {
                                // no response expected
                                // TODO: this might be bad because `ApplyUpdate` could fail, hence
                                // an error-message is sent
                                continue;
                            }
                            // Non-update messages always carry a result
                            // channel by construction (see `evaluate`).
                            let result_tx = result_tx.unwrap();
                            // wait for response; a worker may stream
                            // several results before signalling `Done`.
                            loop {
                                let msg = recv_msg(&client);
                                match msg {
                                    ResultMsg::ObjectiveValue { value, .. } => result_tx.objective(value).unwrap(),
                                    ResultMsg::Minorant { minorant, .. } => result_tx.minorant(minorant).unwrap(),
                                    ResultMsg::Done { .. } => break,
                                    ResultMsg::Error { error, .. } => {
                                        result_tx.error(Error::OracleError(error)).unwrap()
                                    }
                                }
                            }
                        }
                        info!("Terminate worker thread {}", rank);
                        send_msg(&client, &WorkerMsg::<P::Update>::Terminate);
                    });
                    tx
                })
                .collect();
            self.mpidata = Some(MPIData {
                nclients: world.size() - 1,
                free_clients,
                next_client: 1,
                client_txs,
            });
        }
    }

    /// Shut down the worker pool (dropping the channels terminates the
    /// sender threads) and stop the wrapped problem.
    fn stop(&mut self) {
        self.mpidata.take();
        self.problem.write().unwrap().stop();
    }

    /// Dispatch evaluation of subproblem `i` at point `y` to a worker
    /// rank; results arrive asynchronously through `tx`.
    fn evaluate<S>(&mut self, i: usize, y: Arc<DVector>, tx: S) -> Result<(), Self::Err>
    where
        S: ResultSender<Self> + 'static,
        Self: Sized,
    {
        if self.mpidata.is_none() {
            self.start()
        }
        let mpidata = self.mpidata.as_mut().unwrap();
        // get client: prefer an idle rank, otherwise round-robin.
        // NOTE(review): ranks are never pushed back onto
        // `free_clients`, so after the first `nclients` dispatches this
        // always falls through to round-robin — confirm intended.
        let rank = mpidata.free_clients.pop_front().unwrap_or_else(|| {
            let r = mpidata.next_client;
            let n = mpidata.nclients; // rank 0 is this node
            // Cycle through 1..=nclients. NOTE(review): panics with a
            // division by zero if there are no worker ranks (n == 0).
            mpidata.next_client = (mpidata.next_client % n) + 1;
            r
        }) as usize;
        mpidata.client_txs[rank - 1]
            .send((WorkerMsg::Evaluate { i, y }, Some(Box::new(tx))))
            .unwrap();
        Ok(())
    }

    /// Ask the wrapped problem for a model update; when one is
    /// produced, broadcast it to all worker ranks, apply it locally and
    /// forward it to the main algorithm via `tx`.
    fn update<U, S>(&mut self, state: U, tx: S) -> Result<(), Self::Err>
    where
        U: UpdateState<<Self::Minorant as Minorant>::Primal>,
        S: UpdateSender<Self> + 'static,
        Self: Sized,
    {
        if self.mpidata.is_none() {
            self.start()
        }
        let mpidata = self.mpidata.as_mut().unwrap();
        // Clones moved into the (possibly asynchronous) apply callback.
        let problem = self.problem.clone();
        let client_txs = mpidata.client_txs.clone();
        self.problem
            .write()
            .unwrap()
            .compute_update(state, move |update| {
                let update = Arc::new(update);
                let world = SystemCommunicator::world();
                // Ship the update to every worker rank first …
                for rank in 1..world.size() as usize {
                    client_txs[rank - 1]
                        .send((WorkerMsg::ApplyUpdate(update.clone()), None))
                        .unwrap()
                }
                // … then apply it to the local copy and report it to
                // the algorithm. Errors here are the oracle's (`P::Err`).
                let mut problem = problem.write().unwrap();
                problem.apply_update(&update)?;
                problem.send_update(&update, WorkerUpdateSender(tx))?;
                Ok(())
            })
            .map_err(Error::OracleError)?;
        Ok(())
    }
}
/// Adapter handed to the inner problem `P` in `send_update`: forwards
/// everything to an [`UpdateSender`] of the wrapping [`Problem<P>`],
/// translating the oracle's errors into [`Error::OracleError`].
struct WorkerUpdateSender<S>(S);

impl<P, S> UpdateSender<P> for WorkerUpdateSender<S>
where
    P: DistributedFirstOrderProblem + 'static,
    P::Minorant: for<'a> Deserialize<'a>,
    P::Err: for<'a> Deserialize<'a>,
    S: UpdateSender<Problem<P>>,
{
    /// Forward a variable addition, wrapping the subgradient extender
    /// so that its errors surface as oracle errors of the outer problem.
    fn add_variables(
        &self,
        bounds: Vec<(Real, Real)>,
        sgext: Box<dyn SubgradientExtender<P::Minorant, P::Err>>,
    ) -> Result<(), UpdateSendError> {
        let wrapped = move |fidx: usize, m: &mut P::Minorant| {
            sgext.extend_subgradient(fidx, m).map_err(Error::OracleError)
        };
        self.0.add_variables(bounds, Box::new(wrapped))
    }

    /// Forward an oracle error, wrapped into the outer error type.
    fn error(&self, err: P::Err) -> Result<(), UpdateSendError> {
        self.0.error(Error::OracleError(err))
    }
}