/*
* Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
*
* This program is free software: you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>
*/
use super::msg::{ResultMsg, WorkerMsg};
use crate::problem::{FirstOrderProblem, ResultSendError, ResultSender};
use crate::Real;
use flexbuffers;
use log::info;
use mpi::environment::Universe;
use mpi::point_to_point::{Destination, Source};
use mpi::topology::Communicator;
use serde::{Deserialize, Serialize};
use crate::mpi::problem::DistributedFirstOrderProblem;
use std::sync::mpsc::{channel, Sender};
/// Adapter that forwards evaluation results of a single sub-problem to the
/// worker's local result loop via an in-process mpsc channel.
///
/// One instance is handed to `P::evaluate` per `WorkerMsg::Evaluate` request;
/// when it is dropped, a `ResultMsg::Done` sentinel is emitted (see the `Drop`
/// impl) so the forwarding loop in `Worker::run` knows the evaluation finished.
struct WorkerResultSender<P: FirstOrderProblem> {
// Index of the sub-problem this sender reports for; stamped on every message.
index: usize,
// Channel into the result-forwarding loop of `Worker::run`.
sender: Sender<ResultMsg<P::Minorant, P::Err>>,
}
impl<P: FirstOrderProblem> Drop for WorkerResultSender<P> {
    /// Emits the `Done` sentinel so the result-forwarding loop in
    /// `Worker::run` stops waiting for further messages of this evaluation.
    fn drop(&mut self) {
        // Best-effort send: if the receiving end of the channel is already
        // gone (the forwarding loop has exited), `send` returns an error.
        // Panicking here — previously via `.unwrap()` — is dangerous in a
        // destructor: if we are dropped during an unwind, the second panic
        // aborts the whole process. Losing the sentinel is harmless in that
        // case because nobody is listening anymore.
        let _ = self.sender.send(ResultMsg::Done { index: self.index });
    }
}
impl<P: FirstOrderProblem> ResultSender<P> for WorkerResultSender<P> {
fn objective(&self, value: Real) -> Result<(), ResultSendError> {
self.sender
.send(ResultMsg::ObjectiveValue {
index: self.index,
value,
})
.map_err(|_| ResultSendError::Connection)
}
fn minorant(&self, minorant: P::Minorant) -> Result<(), ResultSendError> {
self.sender
.send(ResultMsg::Minorant {
index: self.index,
minorant,
})
.map_err(|_| ResultSendError::Connection)
}
fn error(&self, err: P::Err) -> Result<(), ResultSendError> {
self.sender
.send(ResultMsg::Error {
index: self.index,
error: err,
})
.map_err(|_| ResultSendError::Connection)
}
}
/// An MPI worker process that evaluates sub-problems on behalf of the
/// rank-0 driver process.
///
/// The worker owns its MPI universe and its local copy of the problem; it
/// serves requests in a blocking loop until the driver sends `Terminate`
/// (see `Worker::run`).
pub struct Worker<P: FirstOrderProblem + 'static> {
// MPI environment handle; kept alive for the lifetime of the worker.
universe: Universe,
// Local problem instance updated and evaluated on request.
problem: P,
}
impl<P: DistributedFirstOrderProblem> Worker<P>
where
    // Results are serialized with flexbuffers before being sent back to rank 0.
    P::Minorant: Serialize,
    P::Err: Serialize,
{
/// Creates a worker bound to the given MPI universe and local problem copy.
pub fn new(universe: Universe, problem: P) -> Self {
Worker { universe, problem }
}
/// Serves requests from the rank-0 driver until a `Terminate` message arrives.
///
/// Protocol (one iteration per incoming `WorkerMsg`):
/// - `Terminate`: leave the loop and return.
/// - `Update`: apply the shared update to the local problem; a failure is
///   considered fatal for the worker process.
/// - `Evaluate { i, y }`: evaluate sub-problem `i` at point `y`, forwarding
///   every produced result message back to rank 0 until the `Done` sentinel
///   for this evaluation is seen.
///
/// NOTE(review): the `unwrap`s on deserialization assume rank 0 always sends
/// well-formed flexbuffers payloads — a malformed message panics the worker.
pub fn run(&mut self) {
let world = self.universe.world();
// The driver/client always runs at rank 0.
let client = world.process_at_rank(0);
loop {
let (raw, _) = client.receive_vec::<u8>();
let r = flexbuffers::Reader::get_root(&raw[..]).unwrap();
let msg = WorkerMsg::deserialize(r).unwrap();
match msg {
WorkerMsg::Terminate => break,
WorkerMsg::Update(update) => self
.problem
.apply_update(&update)
// Map to a String before unwrapping; presumably `P::Err` need
// not be `Debug`, which plain `.unwrap()` would require.
.map_err(|_| "Apply error".to_string())
.unwrap(),
WorkerMsg::Evaluate { i, y } => {
// Local channel: `evaluate` pushes results through the
// `WorkerResultSender`, this loop forwards them to rank 0.
let (client_tx, client_rx) = channel();
if let Err(err) = self.problem.evaluate(
i,
y,
WorkerResultSender {
index: i,
sender: client_tx,
},
) {
// `evaluate` itself failed: report the error to rank 0
// directly. The sender moved into `evaluate` has been
// dropped by now, so a `Done` sentinel is already queued
// and the forwarding loop below will still terminate.
let mut s = flexbuffers::FlexbufferSerializer::new();
ResultMsg::<P::Minorant, _>::Error { index: i, error: err }
.serialize(&mut s)
.unwrap();
client.send(s.view());
}
// Forward every buffered result message to rank 0; the `Done`
// sentinel (sent when the `WorkerResultSender` is dropped)
// marks the end of this evaluation. `recv` can only fail if
// all senders are gone without a `Done`, which would be a bug.
loop {
let rmsg = client_rx.recv().expect("channel receive error");
let mut s = flexbuffers::FlexbufferSerializer::new();
rmsg.serialize(&mut s).unwrap();
client.send(s.view());
if let ResultMsg::Done { .. } = rmsg {
break;
}
}
}
}
}
info!("Terminate worker process {}", world.rank());
}
}