/*
* Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
*
* This program is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>
*/
use super::msg::{ResultMsg, ResultType, WorkerMsg};
use crate::problem::{FirstOrderProblem, ResultSendError, ResultSender};
use crate::{DVector, Real};
use flexbuffers;
use log::info;
use mpi::environment::Universe;
use mpi::point_to_point::{Destination, Source};
use mpi::topology::Communicator;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::mpi::problem::DistributedFirstOrderProblem;
use std::sync::mpsc::{channel, SendError, Sender};
use std::sync::Arc;
/// Message passed from a [`WorkerResultSender`] to the worker loop over an
/// in-process channel while one evaluation is running.
#[derive(Debug)]
enum EvalResult<P: FirstOrderProblem> {
    /// An objective value reported by the problem.
    ObjectiveValue(Real),
    /// A minorant reported by the problem.
    Minorant(P::Minorant),
    /// An error reported by the problem.
    Error(P::Err),
    /// Emitted when the result sender is dropped; no further results follow.
    Done,
}
/// Specialized error for channels.
///
/// We cannot use `SendError<...>` directly because the user type is
/// not required to be `Sync`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
/// Sending an evaluation result failed.
#[error("sending evaluation result failed")]
SendEval,
}
/// Result sender handed to the problem's `evaluate` call.
///
/// Every reported result is wrapped in an [`EvalResult`] and pushed into an
/// in-process channel.
struct WorkerResultSender<P>
where
    P: FirstOrderProblem,
{
    /// Sending half of the channel carrying the evaluation results.
    sender: Sender<EvalResult<P>>,
}
impl<P: FirstOrderProblem> Drop for WorkerResultSender<P> {
    /// Signals the end of an evaluation by sending [`EvalResult::Done`].
    ///
    /// This never panics: a failed send is ignored.
    fn drop(&mut self) {
        // `send` only fails when the receiving half has already been dropped,
        // in which case nobody is waiting for the `Done` marker any more.
        // Unwrapping here would panic inside a destructor, and a panic while
        // the thread is already unwinding aborts the whole process.
        let _ = self.sender.send(EvalResult::Done);
    }
}
impl<P: FirstOrderProblem> ResultSender<P> for WorkerResultSender<P> {
fn objective(&self, value: Real) -> Result<(), ResultSendError> {
self.sender
.send(EvalResult::ObjectiveValue(value))
.map_err(|_| ResultSendError::Connection)
}
fn minorant(&self, minorant: P::Minorant) -> Result<(), ResultSendError> {
self.sender
.send(EvalResult::Minorant(minorant))
.map_err(|_| ResultSendError::Connection)
}
fn error(&self, err: P::Err) -> Result<(), ResultSendError> {
self.sender
.send(EvalResult::Error(err))
.map_err(|_| ResultSendError::Connection)
}
}
/// A worker process that serves requests received over MPI from rank 0.
pub struct Worker<P>
where
    P: FirstOrderProblem + 'static,
{
    /// The MPI universe this worker communicates in.
    universe: Universe,
    /// The problem that updates and evaluations are applied to.
    problem: P,
}
impl<P: DistributedFirstOrderProblem> Worker<P>
where
P::Minorant: Serialize,
P::Err: Serialize,
{
pub fn new(universe: Universe, problem: P) -> Self {
Worker { universe, problem }
}
pub fn run(&mut self) {
let world = self.universe.world();
let client = world.process_at_rank(0);
loop {
let (raw, _) = client.receive_vec::<u8>();
let r = flexbuffers::Reader::get_root(&raw[..]).unwrap();
let msg = WorkerMsg::deserialize(r).unwrap();
match msg {
WorkerMsg::Terminate => break,
WorkerMsg::Update(update) => self
.problem
.apply_update(&update)
.map_err(|_| "Apply error".to_string())
.unwrap(),
WorkerMsg::Evaluate { i, y } => {
let (client_tx, client_rx) = channel();
if let Err(err) = self.problem.evaluate(i, y, WorkerResultSender { sender: client_tx }) {
panic!("Some error")
}
loop {
let cmsg = client_rx.recv().expect("channel receive error");
match cmsg {
EvalResult::ObjectiveValue(objval) => {
client.send(&ResultMsg {
index: i,
typ: ResultType::ObjectiveValue,
n: 1,
});
client.send(&objval);
}
EvalResult::Minorant(minorant) => {
let mut s = flexbuffers::FlexbufferSerializer::new();
minorant.serialize(&mut s).unwrap();
client.send(&ResultMsg {
index: i,
typ: ResultType::Minorant,
n: s.view().len(),
});
client.send(s.view());
}
EvalResult::Error(err) => {
client.send(&ResultMsg {
index: i,
typ: ResultType::Error,
n: 0,
});
//client.send(&err);
}
EvalResult::Done => {
client.send(&ResultMsg {
index: i,
typ: ResultType::Done,
n: 0,
});
break;
}
}
}
}
}
}
info!("Terminate worker process {}", world.rank());
}
}