RsBundle  Artifact [1691ecc716]

Artifact 1691ecc716a021e0b107edde9045026c3e9a0ddb:

  • File src/mpi/worker.rs — part of check-in [ee26318760] at 2023-04-06 15:36:36 on branch mpi — mpi::worker: send error from `evaluate` to main process (user: fifr size: 4486)

/*
 * Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
 */

use super::msg::{ResultMsg, WorkerMsg};
use crate::problem::{FirstOrderProblem, ResultSendError, ResultSender};
use crate::Real;

use flexbuffers;
use log::info;
use mpi::environment::Universe;
use mpi::point_to_point::{Destination, Source};
use mpi::topology::Communicator;
use serde::{Deserialize, Serialize};

use crate::mpi::problem::DistributedFirstOrderProblem;
use std::sync::mpsc::{channel, Sender};

/// Channel-backed [`ResultSender`] handed to `evaluate` by the worker loop.
///
/// Every result is tagged with the subproblem `index` and queued on the local
/// `sender`, from where the worker loop forwards it to the main process.
struct WorkerResultSender<P>
where
    P: FirstOrderProblem,
{
    /// Local channel on which result messages are queued.
    sender: Sender<ResultMsg<P::Minorant, P::Err>>,
    /// Index of the subproblem whose evaluation this sender reports on.
    index: usize,
}

impl<P: FirstOrderProblem> Drop for WorkerResultSender<P> {
    /// Signals the end of the evaluation of subproblem `self.index` by queueing
    /// `ResultMsg::Done` on the local channel.
    ///
    /// If the send fails, the receiving end has already been dropped and nobody
    /// is waiting for this notification any more, so the failure is ignored.
    /// Panicking here instead would abort the whole process whenever the sender
    /// happens to be dropped during unwinding (a panic while panicking).
    fn drop(&mut self) {
        // Deliberately best-effort; see the doc comment above.
        let _ = self.sender.send(ResultMsg::Done { index: self.index });
    }
}

impl<P: FirstOrderProblem> ResultSender<P> for WorkerResultSender<P> {
    fn objective(&self, value: Real) -> Result<(), ResultSendError> {
        self.sender
            .send(ResultMsg::ObjectiveValue {
                index: self.index,
                value,
            })
            .map_err(|_| ResultSendError::Connection)
    }

    fn minorant(&self, minorant: P::Minorant) -> Result<(), ResultSendError> {
        self.sender
            .send(ResultMsg::Minorant {
                index: self.index,
                minorant,
            })
            .map_err(|_| ResultSendError::Connection)
    }

    fn error(&self, err: P::Err) -> Result<(), ResultSendError> {
        self.sender
            .send(ResultMsg::Error {
                index: self.index,
                error: err,
            })
            .map_err(|_| ResultSendError::Connection)
    }
}

/// A worker process that evaluates a problem on request of the main process
/// (rank 0).
///
/// Field order matters: fields are dropped in declaration order, so
/// `universe` is dropped before `problem`.
pub struct Worker<P>
where
    P: FirstOrderProblem + 'static,
{
    /// MPI universe used to reach the main process.
    universe: Universe,
    /// Problem instance that is updated and evaluated locally.
    problem: P,
}

impl<P: DistributedFirstOrderProblem> Worker<P>
where
    P::Minorant: Serialize,
    P::Err: Serialize,
{
    pub fn new(universe: Universe, problem: P) -> Self {
        Worker { universe, problem }
    }

    pub fn run(&mut self) {
        let world = self.universe.world();
        let client = world.process_at_rank(0);

        loop {
            let (raw, _) = client.receive_vec::<u8>();
            let r = flexbuffers::Reader::get_root(&raw[..]).unwrap();
            let msg = WorkerMsg::deserialize(r).unwrap();

            match msg {
                WorkerMsg::Terminate => break,
                WorkerMsg::Update(update) => self
                    .problem
                    .apply_update(&update)
                    .map_err(|_| "Apply error".to_string())
                    .unwrap(),
                WorkerMsg::Evaluate { i, y } => {
                    let (client_tx, client_rx) = channel();

                    if let Err(err) = self.problem.evaluate(
                        i,
                        y,
                        WorkerResultSender {
                            index: i,
                            sender: client_tx,
                        },
                    ) {
                        let mut s = flexbuffers::FlexbufferSerializer::new();
                        ResultMsg::<P::Minorant, _>::Error { index: i, error: err }
                            .serialize(&mut s)
                            .unwrap();
                        client.send(s.view());
                    }

                    loop {
                        let rmsg = client_rx.recv().expect("channel receive error");
                        let mut s = flexbuffers::FlexbufferSerializer::new();
                        rmsg.serialize(&mut s).unwrap();
                        client.send(s.view());
                        if let ResultMsg::Done { .. } = rmsg {
                            break;
                        }
                    }
                }
            }
        }

        info!("Terminate worker process {}", world.rank());
    }
}