Kestrel-3

Artifact [5c47062092]
Login

Artifact 5c470620924b19d7fb25a818c1889cd9bd10e04d506eb085c0c30cb1e73ce66f:


#!/usr/bin/env python3
#
# CSR (Control and Status Register) support code.


import argparse

from nmigen import (
    cli,
    Elaboratable, Signal, Module, Const, C, Cat,
)

from kcp53k import get_csr_priv


class CSR(Elaboratable):
    """Common port bundle for a CSR that plugs into the KCP53000 C port.

    Refer to the KCP53000 or KCP53000B documentation for details on how
    this port works.

    This base class only declares the interface signals; concrete CSRs
    supply their own elaborate method.  Address decoding and the
    privilege check can be reused by composing a CSRSelect.
    """

    def __init__(self):
        # Signals driven by the processor.
        self.i_adr = Signal(12)   # address of the CSR being accessed
        self.i_dat = Signal(64)   # data presented for a write
        self.i_priv = Signal(2)   # current privilege level of the access
        self.i_ree = Signal()     # NOTE(review): unused by CSRs in this file; see KCP53000 docs
        self.i_we = Signal()      # write enable
        # Signals driven by the CSR.
        self.o_dat = Signal(64)   # read data
        self.o_valid = Signal()   # high when the access targets this CSR

    @property
    def port_list(self):
        """Return a fresh list of every interface signal, inputs first."""
        inputs = [self.i_adr, self.i_dat, self.i_priv, self.i_ree, self.i_we]
        outputs = [self.o_dat, self.o_valid]
        return inputs + outputs


class CSRSelect(Elaboratable):
    """Used to implement a CSR at a given address.

    To use this class, you need to *compose* it into your CSR class as
    a submodule.  Do not inherit from this class.  Then, in your CSR's
    elaborate method, add the selector to your module's submodules and
    call its connect_to method, which wires the selector to your CSR's
    i_adr, i_priv, and o_valid signals.

    Although not strictly a part of the RISC-V privilege
    specification, some bits of the CSR's address are taken to be the
    minimum privilege level the processor MUST be at in order to
    successfully access it.  This class implements the privilege check
    logic for you, so you need not worry about it.  However, be sure
    to conform to the established CSR addressing guidelines!

    Example:

    class MyCSR(CSR):
        def __init__(self, addr=CSR_DEFAULT_ADDR):
            super().__init__()
            self.addr = addr

        def elaborate(self, platform):
            m = Module()
            selector = CSRSelect(self.addr)
            m.submodules += selector
            selector.connect_to(self, m.d.comb)

            # ...drive self.o_dat, handle self.i_we, etc...
            return m

    """

    def __init__(self, csr_addr):
        """Constructs a CSR address decoder.

        :param csr_addr: The address of the CSR.  Suggestion: avoid
            hard-coded numbers like 0x300.  Choose instead symbolic
            names, such as CSR_MSTATUS.  NOTE: Only the lower 12 bits
            are valid; the address is held as a 12-bit constant.  The
            required privilege level is derived from the address via
            get_csr_priv.
        """
        self.csr_addr = Const(csr_addr, 12)
        self.priv = Const(get_csr_priv(csr_addr), 2)

        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.i_valid = Signal()
        self.o_valid = Signal()

    @property
    def port_list(self):
        return (
            self.i_adr,
            self.i_priv,
            self.i_valid,
            self.o_valid,
        )

    @property
    def name(self):
        return "CSRSelect"

    def elaborate(self, platform):
        """Builds the address decode logic as a standalone Module.

        o_valid is asserted combinatorially when i_adr equals this CSR's
        address, i_priv is at least the CSR's required privilege level,
        and i_valid is asserted.  The returned Module is meant to be added
        to the outer CSR's submodules.
        """
        m = Module()
        m.d.comb += self.o_valid.eq(
            (self.i_adr == self.csr_addr) &
            (self.i_priv >= self.priv) &
            self.i_valid
        )
        return m

    def connect_to(self, outer, domain, valid=None):
        """Wires the CSR address decoder up to the interface of the outer CSR it's used with.

        :param outer: The CSR this decoder serves.  Its i_adr and i_priv
            signals feed this decoder's inputs, and its o_valid signal is
            driven from this decoder's o_valid.

        :param domain: The combinatorial domain to attach the wiring
            to, typically the outer module's m.d.comb.

        :param valid: If not None, a Signal(1) which ultimately gates
            the o_valid signal.  If None, it's assumed to be hardwired
            to 1.  If left unspecified, defaults to None.
        """

        if valid is None:
            valid = C(1)

        domain += [
            self.i_adr.eq(outer.i_adr),
            self.i_priv.eq(outer.i_priv),
            self.i_valid.eq(valid),
            outer.o_valid.eq(self.o_valid),
        ]


class ConstantCSR(CSR):
    """A read-only CSR whose read data is a fixed number.

    The value is held as a 64-bit constant and driven onto o_dat at all
    times.  i_we and i_dat are never consulted, so this CSR has no
    write path; o_valid is produced by an ungated CSRSelect.
    """

    def __init__(self, csr_addr, value=0):
        super().__init__()
        self.addr = csr_addr
        self.value = Const(value, 64)

    @property
    def name(self):
        return "ConstantCSR"

    def elaborate(self, platform):
        module = Module()

        decoder = CSRSelect(self.addr)
        module.submodules += decoder
        decoder.connect_to(self, module.d.comb)

        module.d.comb += self.o_dat.eq(self.value)
        return module


class InputCSR(CSR):
    """Used to expose a set of general purpose inputs as a read-only
    CSR.  If the InputCSR is specified with just a signal_width
    keyword argument, then the low-order bits of the CSR will map
    directly to the input signal, bit for bit.  Higher-order bits
    will be hardwired zero.

    However, if signal_bits is provided with a list of ordinal bit
    positions, then each bit of the signal input will be distributed
    to the specified CSR bit positions (in ascending order).  So,
    if signal_bits=[5,4,3,31,30], then the input signal will have
    a width of 5 bits, and signal[0] will map to o_dat[3], signal[1]
    to o_dat[4], signal[2] to o_dat[5], signal[3] to o_dat[30], and
    signal[4] to o_dat[31].  All unmapped bits will be hardwired zero.

    Specifying both signal_width and signal_bits at the same time is
    an error.  Specifying neither is also an error.

    This CSR has no write path: i_we and i_dat are never consulted.
    """

    @property
    def name(self):
        return "InputCSR"

    def __init__(
        self, csr_addr,
        signal_width=None, signal_bits=None, valid=None
    ):
        """Creates an InputCSR instance.

        :param csr_addr: The address of the CSR (see CSRSelect).
        :param signal_width: Width of i_signal; its bits map onto the
            low-order CSR bits o_dat[0:signal_width].
        :param signal_bits: Iterable of distinct CSR bit positions, each
            within range(0, 64), that receive the bits of i_signal in
            ascending position order.
        :param valid: If not None, a Signal(1) which gates o_valid.  If
            None, no gating occurs (equivalent to a hardwired 1).

        :raises ValueError: if neither or both of signal_width and
            signal_bits are given, or if the resulting bit positions
            contain a duplicate or fall outside the CSR's width.
        """
        super().__init__()

        if signal_width is None and signal_bits is None:
            raise ValueError(
                "At least one of signal_width and signal_bits "
                "must be specified."
            )

        if signal_width is not None and signal_bits is not None:
            raise ValueError(
                "Only one of signal_width or signal_bits "
                "must be specified."
            )

        # If the user specified just a raw signal width attribute,
        # then convert this into the corresponding signal bits list.
        if signal_bits is None:
            signal_bits = range(0, signal_width)

        # Materialize the positions so one-shot iterables (generators)
        # survive both len() below and the lookups in elaborate().
        signal_bits = list(signal_bits)

        # Reject positions that could never be mapped; previously they
        # silently widened i_signal with bits that went nowhere.
        xlen = self.o_dat.nbits
        out_of_range = [b for b in signal_bits if not 0 <= b < xlen]
        if out_of_range:
            raise ValueError(
                "signal_bits positions must lie within range(0, {}); "
                "got {}".format(xlen, out_of_range)
            )
        if len(set(signal_bits)) != len(signal_bits):
            raise ValueError("signal_bits must not contain duplicates.")

        self.addr = csr_addr
        self.sigbits = signal_bits

        self.i_signal = Signal(len(self.sigbits))
        self.valid = valid

    def elaborate(self, platform):
        m = Module()
        selector = CSRSelect(self.addr)
        m.submodules += selector
        selector.connect_to(self, m.d.comb, valid=self.valid)

        # i_signal[j] feeds the j-th lowest selected CSR bit position.
        source_of = {
            pos: j for j, pos in enumerate(sorted(self.sigbits))
        }
        for i in range(self.o_dat.nbits):
            if i in source_of:
                m.d.comb += self.o_dat[i].eq(self.i_signal[source_of[i]])
            else:
                m.d.comb += self.o_dat[i].eq(0)

        return m


class MemCSR(CSR):
    """Some CSRs, like MSCRATCH and MTVEC, just hold configuration
    data which holds little to no particular meaning to the processor.
    MemCSRs are ideal for describing such registers.
    """

    @property
    def name(self):
        return "MemCSR"

    def __init__(self, csr_addr, retain_bits, valid=None):
        """Creates a MemCSR instance.  The bits to retain in memory (DFFs)
        are notated as a list of ordinal positions.  For example, the MIE
        register could be specified with the list [3, 7, 11], since
        only bits 3, 7, and 11 are implemented.  Contrast against
        MTVEC, which implements bits range(2,64).  If vectored mode
        becomes available (MTVEC MODE field value 1), the retain_bits
        list would be [0] + list(range(2,64)).

        When reading out the contents of the CSR, unsupported bits are
        hardwired as 0.  Writes to unsupported bits are discarded.

        If you need input validation for the CSR, the o_valid signal
        can be gated by the signal specified by the valid keyword
        argument.  If not specified, no gating occurs (and is
        equivalent to a hardwired 1).  Because writes require o_valid,
        the same signal also gates writes.

        :param csr_addr: The address of the CSR (see CSRSelect).
        :param retain_bits: Iterable of distinct CSR bit positions, each
            within range(0, 64), that are backed by storage.  It is
            stored as a list, so ranges and generators are accepted.
        :param valid: Optional Signal(1) gating o_valid.

        :raises ValueError: if retain_bits contains a duplicate or a
            position outside the CSR's width.
        """

        super().__init__()

        # Materialize so one-shot iterables survive len() and elaborate().
        retain_bits = list(retain_bits)

        # Out-of-range or duplicate positions would silently allocate
        # storage bits that can never be read or written.
        xlen = self.o_dat.nbits
        out_of_range = [b for b in retain_bits if not 0 <= b < xlen]
        if out_of_range:
            raise ValueError(
                "retain_bits positions must lie within range(0, {}); "
                "got {}".format(xlen, out_of_range)
            )
        if len(set(retain_bits)) != len(retain_bits):
            raise ValueError("retain_bits must not contain duplicates.")

        self.addr = csr_addr
        self.retained_bits = retain_bits
        self.bits = Signal(len(self.retained_bits))
        self.valid = valid

    def elaborate(self, platform):
        m = Module()
        selector = CSRSelect(self.addr)
        m.submodules += selector
        selector.connect_to(self, m.d.comb, valid=self.valid)

        # Storage bit bits[j] backs the j-th lowest retained CSR position.
        # Computed once and shared by the read and write paths.
        slot_of = {
            pos: j for j, pos in enumerate(sorted(self.retained_bits))
        }

        # Read path: retained positions come from storage, others read 0.
        for i in range(0, self.o_dat.nbits):
            if i in slot_of:
                m.d.comb += self.o_dat[i].eq(self.bits[slot_of[i]])
            else:
                m.d.comb += self.o_dat[i].eq(C(0, 1))

        # Write path: on a valid write, latch only the retained positions.
        with m.If(self.o_valid & self.i_we):
            for pos, j in slot_of.items():
                m.d.sync += self.bits[j].eq(self.i_dat[pos])

        return m


class UpCounterCSR(CSR):
    """Some CSRs, like MCYCLE or MINSTRET, are up-counters.  Like MemCSRs,
    the processor may alter their contents at any time.  Unlike MemCSRs,
    however, they increment by one on each clock cycle in which their
    count-enable input (i_tick) is asserted.  A write to the CSR in the
    same cycle takes priority over the increment.
    """

    @property
    def name(self):
        return "UpCounterCSR"

    @property
    def port_list(self):
        return super().port_list + [self.i_tick]

    def __init__(self, csr_addr, counter_width=None, valid=None):
        """Creates an UpCounterCSR instance.  The width of the counter is
        specified by the counter_width keyword argument.  When reading
        out the contents of the CSR, unsupported bits are hardwired as
        0.

        If left unspecified (None, or 0), the default counter width is
        XLEN bits.

        :param csr_addr: The address of the CSR (see CSRSelect).
        :param counter_width: Number of implemented counter bits, from 1
            up to the CSR's width.
        :param valid: Optional Signal(1) gating o_valid, and therefore
            writes.  If None, no gating occurs (equivalent to a
            hardwired 1).

        :raises ValueError: if counter_width is negative or wider than
            the CSR.
        """

        super().__init__()

        xlen = self.o_dat.nbits
        if not counter_width:
            counter_width = xlen

        # Fail early with a clear message instead of a negative-width
        # padding constant.
        if not 0 < counter_width <= xlen:
            raise ValueError(
                "counter_width must lie within 1..{}; got {}"
                .format(xlen, counter_width)
            )

        self.addr = csr_addr
        self.counter = Signal(counter_width)
        self.padding = C(0, xlen - counter_width)
        self.valid = valid

        self.i_tick = Signal()

    def elaborate(self, platform):
        m = Module()
        selector = CSRSelect(self.addr)
        m.submodules += selector
        selector.connect_to(self, m.d.comb, valid=self.valid)

        # We avoid the use of m.Elif() here because it
        # results in Verilog that triggers Verilator's
        # linter for overlapping cases.

        write_ctr = Signal()
        m.d.comb += write_ctr.eq(self.i_we & self.o_valid)

        with m.If(write_ctr):
            m.d.sync += self.counter.eq(self.i_dat[0:self.counter.nbits])
        with m.If((~write_ctr) & self.i_tick):
            m.d.sync += self.counter.eq(self.counter + 1)

        # Unimplemented upper bits read as zero.
        m.d.comb += self.o_dat.eq(Cat(self.counter, self.padding))

        return m


if __name__ == '__main__':
    # One factory per selectable module; each builds a small sample
    # instance at address 0x123 for nmigen's command-line runner.
    factories = {
        "CSRSelect": lambda: CSRSelect(0x123),
        "ConstantCSR": lambda: ConstantCSR(0x123, value=0x456),
        "InputCSR": lambda: InputCSR(0x123, 52),
        "MemCSR": lambda: MemCSR(0x123, retain_bits=range(2, 64)),
        "UpCounterCSR": lambda: UpCounterCSR(0x123),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("module", choices=list(factories))
    cli.main_parser(parser)

    args = parser.parse_args()
    dut = factories[args.module]()

    cli.main_runner(parser, args, dut, ports=dut.port_list, name=dut.name)