Kestrel-3

Artifact [21916d43fc]
Login

Artifact 21916d43fcb3ca7b2758c29ae6d5641fb6f8dadf8ae176ade049b432fc7ef6bf:


#!/usr/bin/env python
#
# Implements the CSR Unit for the KCP53000B.

import argparse

from functools import reduce

from nmigen import *
from nmigen import cli

from common_csr import (
    CSR,
    CSRSelect,
    ConstantCSR,
    InputCSR,
    MemCSR,
    UpCounterCSR,
)
from kcp53k import (
    CSR_CYCLE,
    CSR_INSTRET,
    CSR_MARCHID,
    CSR_MCAUSE,
    CSR_MCOUNTEREN,
    CSR_MCYCLE,
    CSR_MEPC,
    CSR_MHARTID,
    CSR_MHPMCOUNTER0,
    CSR_MHPMEVENT0,
    CSR_MIE,
    CSR_MIMPID,
    CSR_MINSTRET,
    CSR_MIP,
    CSR_MISA,
    CSR_MSCRATCH,
    CSR_MSTATUS,
    CSR_MTVAL,
    CSR_MTVEC,
    CSR_MVENDORID,
    MCOUNTERENB_CY,
    MCOUNTERENB_IR,
    MIEB_MEIE,
    MIEB_MSIE,
    MIEB_MTIE,
    MIPB_MEIP,
    MIPB_MSIP,
    MIPB_MTIP,
    MISA_IF,
    MISA_UF,
    MSTATUSB_MIE,
    MSTATUSB_MPIE,
    MSTATUSB_MPP,
    MSTATUSB_MPRV,
    PRIV_M,
    PRIV_U,
    XLEN_64BIT,
    get_csr_priv,
    put_mimpid_day,
    put_mimpid_month,
    put_mimpid_patch,
    put_mimpid_year,
    put_mxl,
)

# Value reported by the mimpid CSR.  The year/month/day/patch fields are
# given as hex literals whose digits read as the decimal date 2019-03-17,
# patch 00 (a BCD-style encoding).
MIMPID_VALUE = (
    put_mimpid_year(0x2019)
    | put_mimpid_month(0x03)
    | put_mimpid_day(0x17)
    | put_mimpid_patch(0x00)
)

# Value reported by the misa CSR: MXL set for a 64-bit XLEN, plus the
# MISA_UF and MISA_IF extension flags (presumably the U and I extension
# bits -- see kcp53k).
MISA_VALUE = put_mxl(XLEN_64BIT) | MISA_UF | MISA_IF


class MStatus(CSR):
    """Implements the mstatus CSR specified by the Privilege Specification
    v1.10.

    Storage is delegated to a MemCSR submodule.  Only the machine-mode
    fields (MIE, MPIE, MPP, MPRV) and the WPRI bits listed in elaborate()
    are passed to it as retained bits.

    The MemCSR is handed a `valid` flag that is false whenever a write
    would place a privilege mode other than U or M into the MPP field.
    Presumably MemCSR rejects such accesses; see common_csr.

    In addition to the standard CSR ports, the following outputs mirror
    the current contents of individual mstatus fields:

    o_mprv -- MSTATUS.MPRV (1 bit)
    o_mpp  -- MSTATUS.MPP (2 bits)
    o_mpie -- MSTATUS.MPIE (1 bit)
    o_mie  -- MSTATUS.MIE (1 bit)
    """

    @property
    def name(self):
        return "MStatus"

    def __init__(self):
        super().__init__()

        # Expose various MSTATUS fields for easy query
        self.o_mprv = Signal()
        self.o_mpp = Signal(2)
        self.o_mpie = Signal()
        self.o_mie = Signal()

    def elaborate(self, platform):
        m = Module()

        # MPP occupies two adjacent bits starting at MSTATUSB_MPP.  The
        # same slice bounds are used for validation here and for the
        # o_mpp output below, so the two cannot drift apart.
        mpp_lo = MSTATUSB_MPP
        mpp_hi = MSTATUSB_MPP + 2

        # The CSR WLRL fields are valid as long as we're not writing
        # to the CSR, or if the MPP field is about to receive a valid
        # privilege mode (this core implements only U and M).
        wlrl_valid = Signal()
        m.d.comb += wlrl_valid.eq(
            (self.i_we == 0) |
            (self.i_dat[mpp_lo:mpp_hi] == C(PRIV_U, 2)) |
            (self.i_dat[mpp_lo:mpp_hi] == C(PRIV_M, 2))
        )

        # Backing storage for the MSTATUS register.
        #
        # NOTE(review): Privilege Spec v1.10 asks implementations that do
        # not furnish WPRI fields to hardwire them to zero.  Passing them
        # as retained bits may deviate from that.  Confirm MemCSR's
        # retain_bits semantics before changing this list.
        wpri_bits = [
            2,     # formerly HIE
            6,     # formerly HPIE
            9, 10, # formerly HPP
        ] + list(range(23, 32)) + list(range(36, 63))

        # Machine-mode fields implemented by this core.  These are derived
        # from the same MSTATUSB_* constants used for the o_* outputs, so
        # what is stored always matches what is exposed.
        m_bits = [
            MSTATUSB_MIE,
            MSTATUSB_MPIE,
            MSTATUSB_MPP, MSTATUSB_MPP + 1,
            MSTATUSB_MPRV,
        ]

        m.submodules.mem = mem = MemCSR(
            CSR_MSTATUS,
            retain_bits=(wpri_bits + m_bits),
            valid=wlrl_valid
        )

        # Expose backing storage to the outside world.

        m.d.comb += [
            mem.i_adr.eq(self.i_adr),
            mem.i_dat.eq(self.i_dat),
            mem.i_priv.eq(self.i_priv),
            mem.i_ree.eq(self.i_ree),
            mem.i_we.eq(self.i_we),

            self.o_dat.eq(mem.o_dat),
            self.o_valid.eq(mem.o_valid),

            self.o_mprv.eq(mem.o_dat[MSTATUSB_MPRV]),
            self.o_mpp.eq(mem.o_dat[mpp_lo:mpp_hi]),
            self.o_mpie.eq(mem.o_dat[MSTATUSB_MPIE]),
            self.o_mie.eq(mem.o_dat[MSTATUSB_MIE]),
        ]

        return m


class CSRU(CSR):
    """The CSR Unit: every CSR supported by the KCP53000B behind a single
    CSR-style port.

    All child CSRs see the same address, data, privilege, read-effect and
    write-enable inputs.  o_valid is the OR of the children's o_valid
    outputs, and o_dat is the o_dat of whichever child is currently valid.

    Extra inputs beyond the base CSR ports:

    i_hart_id      -- value reported by mhartid (4 bits).
    i_sirq         -- reflected into mip.MSIP.
    i_tirq         -- reflected into mip.MTIP.
    i_eirq         -- reflected into mip.MEIP.
    i_instret_tick -- drives minstret's tick input (instruction completed).
    """

    def __init__(self):
        super().__init__()
        self.i_hart_id = Signal(4)
        self.i_sirq = Signal()
        self.i_tirq = Signal()
        self.i_eirq = Signal()
        self.i_instret_tick = Signal()

    @property
    def name(self):
        return "CSRU"

    @property
    def port_list(self):
        return super().port_list + [
            self.i_hart_id,
            self.i_sirq, self.i_tirq, self.i_eirq,
            self.i_instret_tick,
        ]

    def elaborate(self, platform):
        m = Module()

        is_mcause_valid = Signal()

        # mcounteren only controls access to the user-level counter
        # shadows (cycle, instret) from privilege modes *below* M.
        # Machine mode may always read them, whatever the mcounteren bits
        # say.  Without this term, M-mode reads of cycle/instret would be
        # rejected whenever the corresponding enable bit was clear.
        is_m_mode = Signal()
        m.d.comb += is_m_mode.eq(self.i_priv == PRIV_M)

        mvendorid = ConstantCSR(CSR_MVENDORID)
        marchid = ConstantCSR(CSR_MARCHID)
        mimpid = ConstantCSR(CSR_MIMPID, value=MIMPID_VALUE)
        mhartid = InputCSR(CSR_MHARTID, self.i_hart_id.nbits)
        mstatus = MStatus()
        misa = ConstantCSR(CSR_MISA, value=MISA_VALUE)
        mie = MemCSR(CSR_MIE, retain_bits=[
            MIEB_MSIE, MIEB_MTIE, MIEB_MEIE,
        ])
        mip = InputCSR(CSR_MIP, signal_bits=[
            MIPB_MSIP, MIPB_MTIP, MIPB_MEIP,
        ])
        # Bits 1:0 (MODE) are not retained.
        mtvec = MemCSR(CSR_MTVEC, retain_bits=range(2, 64))
        mcounteren = MemCSR(
            CSR_MCOUNTEREN,
            retain_bits=[MCOUNTERENB_IR, MCOUNTERENB_CY],
        )
        mscratch = MemCSR(CSR_MSCRATCH, retain_bits=range(0, 64))
        mepc = MemCSR(CSR_MEPC, retain_bits=range(0, 64))
        mcause = MemCSR(
            CSR_MCAUSE,
            retain_bits=range(0, 64),
            valid=is_mcause_valid
        )
        mtval = MemCSR(CSR_MTVAL, retain_bits=range(0, 64))
        mcycle = UpCounterCSR(CSR_MCYCLE)
        minstret = UpCounterCSR(CSR_MINSTRET)
        cycle = InputCSR(
            CSR_CYCLE, signal_width=64,
            valid=mcounteren.o_dat[MCOUNTERENB_CY] | is_m_mode,
        )
        instret = InputCSR(
            CSR_INSTRET, signal_width=64,
            valid=mcounteren.o_dat[MCOUNTERENB_IR] | is_m_mode,
        )

        csrs = [
            mvendorid, marchid, mimpid, mhartid,
            # medeleg, mideleg not implemented yet
            mstatus, misa, mie, mtvec, mcounteren,
            mscratch, mepc, mcause, mtval, mip,
            mcycle, minstret, cycle, instret,
        ]

        # mhpmcounter3..31 and mhpmevent3..31 are present only as
        # ConstantCSRs with ConstantCSR's default value.
        for i in range(3, 32):
            csrs.append(ConstantCSR(CSR_MHPMCOUNTER0 + i))
            csrs.append(ConstantCSR(CSR_MHPMEVENT0 + i))

        m.submodules += csrs

        # ALL CSRs receive a copy of the address, input data, and
        # read- and write-effect enables.

        for r in csrs:
            m.d.comb += [
                r.i_priv.eq(self.i_priv),
                r.i_adr.eq(self.i_adr),
                r.i_dat.eq(self.i_dat),
                r.i_ree.eq(self.i_ree),
                r.i_we.eq(self.i_we),
            ]

        # The CSRU's final o_valid is true if any one of the addressed
        # CSRs o_valid is true.  Put another way, our o_valid is false
        # if the i_adr signal does not address a currently supported
        # CSR, or if i_priv does not meet the minimum privilege
        # requirements for the CSR, or if the CSR detects an invalid
        # field to be stored into a write-legal field.

        m.d.comb += self.o_valid.eq(reduce(
            lambda x, y: x | y,
            [r.o_valid for r in csrs]
        ))

        # We need to have access to the outputs of each CSR defined.
        # However, the integer execution unit logic does not.  This
        # creates a multiplexor for all the supported CSRs, so that
        # only the currently addressed CSR's o_dat signal is routed to
        # the IXU.

        m.d.comb += self.o_dat.eq(reduce(
            lambda x, y: x | y,
            [r.o_dat & Repl(r.o_valid, 64) for r in csrs]
        ))

        # Route CSR-specific signals so that they have access to the
        # outside world.

        m.d.comb += [
            mhartid.i_signal.eq(self.i_hart_id),
        ]

        # MCAUSE is an XLEN-wide WLRL register.  I can't fathom why
        # this would be the case; it's a total waste of resources.
        # But, here we are.  We support only bits 63 and 3-0; the
        # remainder of the register needs to be 0.

        m.d.comb += [
            is_mcause_valid.eq(
                (self.i_we == 0) |
                (self.i_dat[4:63] == 0)
            )
        ]

        # MIP is a 3-bit input port reflecting the current external
        # interrupt inputs.
        #
        # MSIP is true if a "software interrupt" is set.  However,
        # this is just a memory-mapped flag, so really, this input
        # can be used for anything you want.
        #
        # MTIP is true if MTIME >= MTIMECMP.  These registers are also
        # memory-mapped and reside outside the processor core, so this
        # input, too, can be used for anything you want.
        #
        # MEIP is true if a purpose-reserved external interrupt
        # is asserted.

        m.d.comb += [
            mip.i_signal[0].eq(self.i_sirq),
            mip.i_signal[1].eq(self.i_tirq),
            mip.i_signal[2].eq(self.i_eirq),
        ]

        # MCYCLE counts up every clock tick.

        m.d.comb += mcycle.i_tick.eq(1)

        # MINSTRET counts up every time an instruction completes.

        m.d.comb += minstret.i_tick.eq(self.i_instret_tick)

        # Expose CYCLE and INSTRET as read-only windows of
        # MCYCLE and MINSTRET, respectively.

        m.d.comb += cycle.i_signal.eq(mcycle.o_dat)
        m.d.comb += instret.i_signal.eq(minstret.o_dat)

        return m


if __name__ == '__main__':
    # Each selectable top-level module name maps to the class that builds
    # it.  argparse restricts the positional argument to these keys.
    builders = {
        "MStatus": MStatus,
        "CSRU": CSRU,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("module", choices=list(builders))
    cli.main_parser(parser)

    args = parser.parse_args()
    dut = builders[args.module]()

    cli.main_runner(parser, args, dut, ports=dut.port_list, name=dut.name)