RsBundle  Artifact [d369680d19]

Artifact d369680d193b5e029591fa7623fc6243fb4a80c7:

  • File src/mpi/worker.rs — part of check-in [1965a45f88] at 2023-03-30 17:16:37 on branch mpi — Remove some debug output (user: fifr size: 5268)

/*
 * Copyright (c) 2023 Frank Fischer <frank-fischer@shadow-soft.de>
 *
 * This program is free software: you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see  <http://www.gnu.org/licenses/>
 */

use super::msg::{EvalMsg, ResultMsg, ResultType};
use crate::problem::{FirstOrderProblem, ResultSendError, ResultSender};
use crate::{DVector, Real};

use flexbuffers;
use log::info;
use mpi::environment::Universe;
use mpi::point_to_point::{Destination, Source};
use mpi::topology::Communicator;
use serde::Serialize;
use thiserror::Error;

use std::sync::mpsc::{channel, SendError, Sender};
use std::sync::Arc;

/// One piece of an evaluation result, passed from a `WorkerResultSender`
/// to the relay loop in `Worker::run`.
#[derive(Debug)]
enum EvalResult<P: FirstOrderProblem> {
    /// An objective value reported by the problem.
    ObjectiveValue(Real),
    /// A minorant reported by the problem.
    Minorant(P::Minorant),
    /// An error reported by the problem.
    Error(P::Err),
    /// Emitted when the `WorkerResultSender` is dropped; ends the relay loop.
    Done,
}

/// Specialized error for channels.
///
/// We cannot use `SendError<...>` directly because the user type is
/// not required to be `Sync`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// Sending an evaluation result over the internal channel failed.
    #[error("sending evaluation result failed")]
    SendEval,
}

impl<T> From<SendError<T>> for Error {
    fn from(err: SendError<T>) -> Error {
        Error::SendEval
    }
}

/// Result sender handed to the problem's `evaluate` call by the worker.
///
/// Every reported result is pushed into an mpsc channel whose receiving
/// end is drained by `Worker::run`.
struct WorkerResultSender<P>
where
    P: FirstOrderProblem,
{
    /// Sending half of the channel read by the worker's relay loop.
    sender: Sender<EvalResult<P>>,
}

impl<P: FirstOrderProblem> Drop for WorkerResultSender<P> {
    /// Signals the end of the evaluation by sending `EvalResult::Done`.
    ///
    /// If the receiving end has already been dropped, nobody is waiting
    /// for the signal any more, so the failure is ignored. Panicking here
    /// instead would abort the whole process if this drop runs while a
    /// panic is already unwinding.
    fn drop(&mut self) {
        // Deliberately ignored: a missing receiver means no one needs `Done`.
        let _ = self.sender.send(EvalResult::Done);
    }
}

impl<P: FirstOrderProblem> ResultSender<P> for WorkerResultSender<P> {
    /// Forwards an objective value to the worker's relay loop.
    ///
    /// # Errors
    ///
    /// Returns `ResultSendError::Connection` if the receiving end of the
    /// channel has been dropped.
    fn objective(&self, value: Real) -> Result<(), ResultSendError> {
        self.sender
            .send(EvalResult::ObjectiveValue(value))
            .map_err(|_| ResultSendError::Connection)
    }

    /// Forwards a minorant to the worker's relay loop.
    ///
    /// # Errors
    ///
    /// Returns `ResultSendError::Connection` if the receiving end of the
    /// channel has been dropped.
    fn minorant(&self, minorant: P::Minorant) -> Result<(), ResultSendError> {
        self.sender
            .send(EvalResult::Minorant(minorant))
            .map_err(|_| ResultSendError::Connection)
    }

    /// Forwards an evaluation error to the worker's relay loop.
    ///
    /// # Errors
    ///
    /// Returns `ResultSendError::Connection` if the receiving end of the
    /// channel has been dropped.
    fn error(&self, err: P::Err) -> Result<(), ResultSendError> {
        self.sender
            .send(EvalResult::Error(err))
            .map_err(|_| ResultSendError::Connection)
    }
}

/// MPI worker that evaluates subproblems of `problem` on request of the
/// process at rank 0 and sends the results back to it.
pub struct Worker<P>
where
    P: FirstOrderProblem + 'static,
{
    /// MPI universe the worker communicates through.
    universe: Universe,
    /// The problem whose subproblems are evaluated.
    problem: P,
}

impl<P: FirstOrderProblem> Worker<P>
where
    P::Minorant: Serialize,
    P::Err: Serialize,
{
    pub fn new(universe: Universe, problem: P) -> Self {
        Worker { universe, problem }
    }

    pub fn run(&mut self) {
        let world = self.universe.world();
        let client = world.process_at_rank(0);

        loop {
            let (msg, _) = client.receive::<EvalMsg>();

            if msg.i == usize::MAX {
                break;
            }

            let mut y = vec![0.0; msg.n];
            let st = client.receive_into(&mut y);
            let y = Arc::new(DVector(y));

            let (client_tx, client_rx) = channel();

            if let Err(err) = self
                .problem
                .evaluate(msg.i, y, WorkerResultSender { sender: client_tx })
            {
                panic!("Some error")
            }

            loop {
                let cmsg = client_rx.recv().expect("channel receive error");
                match cmsg {
                    EvalResult::ObjectiveValue(objval) => {
                        client.send(&ResultMsg {
                            index: msg.i,
                            typ: ResultType::ObjectiveValue,
                            n: 1,
                        });
                        client.send(&objval);
                    }
                    EvalResult::Minorant(minorant) => {
                        let mut s = flexbuffers::FlexbufferSerializer::new();
                        minorant.serialize(&mut s).unwrap();

                        client.send(&ResultMsg {
                            index: msg.i,
                            typ: ResultType::Minorant,
                            n: s.view().len(),
                        });
                        client.send(s.view());
                    }
                    EvalResult::Error(err) => {
                        client.send(&ResultMsg {
                            index: msg.i,
                            typ: ResultType::Error,
                            n: 0,
                        });
                        //client.send(&err);
                    }
                    EvalResult::Done => {
                        client.send(&ResultMsg {
                            index: msg.i,
                            typ: ResultType::Done,
                            n: 0,
                        });
                        break;
                    }
                }
            }
        }

        info!("Terminator worker process {}", world.rank());
    }
}