Kestrel-3

Artifact [96ea88450c]
Login

Artifact 96ea88450c1a3fc69cc1f7f9c071dbddaff5dddd340dd1f2b25a76573c777a38:


#!/usr/bin/env python
#
# Implements the tests for the CSR Unit for the KCP53000B.

from nmigen import Elaboratable, Signal, Module, ClockDomain, C, Const
from nmigen.test.tools import FHDLTestCase
from nmigen.back.pysim import Simulator, Delay
from nmigen.hdl.ast import Past, Assert, ResetSignal

from common_csr import (
    ConstantCSR,
    InputCSR,
    MemCSR,
    CSR,
    CSRSelect,
)
from csru import CSRU, MStatus
from kcp53k import (
    CSR_CYCLE,
    CSR_CYCLE,
    CSR_INSTRET,
    CSR_INSTRET,
    CSR_MARCHID,
    CSR_MCAUSE,
    CSR_MCOUNTEREN,
    CSR_MCOUNTEREN,
    CSR_MCYCLE,
    CSR_MCYCLE,
    CSR_MEPC,
    CSR_MHARTID,
    CSR_MHPMCOUNTER0,
    CSR_MHPMEVENT0,
    CSR_MIE,
    CSR_MIMPID,
    CSR_MINSTRET,
    CSR_MINSTRET,
    CSR_MIP,
    CSR_MISA,
    CSR_MSCRATCH,
    CSR_MSTATUS,
    CSR_MTVAL,
    CSR_MTVEC,
    CSR_MVENDORID,
    MCOUNTERENF_CY,
    MCOUNTERENF_IR,
    MIEB_MEIE,
    MIEB_MSIE,
    MIEB_MTIE,
    MISA_IF,
    MISA_UF,
    MSTATUSB_MIE,
    MSTATUSB_MPIE,
    MSTATUSB_MPP,
    MSTATUSB_MPRV,
    PRIV_M,
    PRIV_U,
    XLEN_64BIT,
    get_csr_priv,
    put_mimpid_day,
    put_mimpid_month,
    put_mimpid_patch,
    put_mimpid_year,
    put_mxl,
)


class MStatusFormal(Elaboratable):
    """Formal harness around the ``MStatus`` CSR.

    Exposes the same ports as ``MStatus`` and wires them straight through to
    an internal instance, then adds these properties:

    * ``o_valid`` stays low whenever ``i_adr`` is not 0x300;
    * ``o_valid`` stays low whenever ``i_priv`` is not 3;
    * when the previous cycle was an accepted write (``i_we`` and ``o_valid``
      high, reset low, and not the very first cycle), each field of ``o_dat``
      either reads as zero or matches the bits that were written.  These
      checks are placed in the sync domain.
    """

    def __init__(self):
        # Ports forwarded to / from the wrapped MStatus instance.
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.i_ree = Signal()
        self.i_we = Signal()
        self.i_dat = Signal(64)
        self.o_dat = Signal(64)
        self.o_valid = Signal()

    def elaborate(self, platform):
        # (lo, hi, follows_write): after an accepted write, o_dat[lo:hi] must
        # equal the previously written i_dat[lo:hi] when follows_write is
        # True, and must read as zero otherwise.
        field_map = (
            (0, 2, False),
            (2, 4, True),
            (4, 6, False),
            (6, 8, True),
            (8, 9, False),
            (9, 13, True),
            (13, 17, False),
            (17, 18, True),
            (18, 23, False),
            (23, 32, True),
            (32, 36, False),
            (36, 63, True),
            (63, 64, False),
        )

        m = Module()

        # Reset to 0 and set to 1 on every clock, so it is low only on the
        # first cycle; used to ignore the cycle where Past() has no history.
        z_past_valid = Signal()
        m.d.sync += z_past_valid.eq(1)

        dut = MStatus()
        m.submodules.dut = dut
        m.d.comb += [
            getattr(dut, name).eq(getattr(self, name))
            for name in ("i_adr", "i_priv", "i_ree", "i_we", "i_dat")
        ]
        m.d.comb += [
            getattr(self, name).eq(getattr(dut, name))
            for name in ("o_dat", "o_valid")
        ]

        # MStatus must only respond at its own address, and only to M-mode.
        with m.If(self.i_adr != Const(0x300, 12)):
            m.d.comb += Assert(~self.o_valid)

        with m.If(self.i_priv != Const(0x3, 2)):
            m.d.comb += Assert(~self.o_valid)

        # High when the previous cycle carried an accepted write outside
        # reset (and the previous cycle actually existed).
        rst = ResetSignal()
        test = Signal(reset_less=True)
        m.d.comb += test.eq(
            Past(self.i_we) & Past(self.o_valid) & ~Past(rst) & z_past_valid
        )

        with m.If(test):
            i_dat = Signal(64)
            m.d.comb += i_dat.eq(Past(self.i_dat))

            m.d.sync += [
                Assert(self.o_dat[lo:hi] ==
                       (i_dat[lo:hi] if follows_write else Const(0, hi - lo)))
                for lo, hi, follows_write in field_map
            ]

        return m


class CSRUTestCase(FHDLTestCase):
    """Simulation and formal tests for the KCP53000B CSR unit (CSRU).

    Each simulation test builds a fresh CSRU with :meth:`setup_csru`, drives
    it from a pysim process, and writes ``test_<name>.vcd`` and
    ``test_<name>.gtkw`` waveform files into the working directory.
    """

    MHARTID_VALUE = 13
    MIMPID_VALUE = (
        put_mimpid_year(0x2019) |
        put_mimpid_month(0x03) |
        put_mimpid_day(0x17) |
        put_mimpid_patch(0x00)
    )
    MISA_VALUE = (put_mxl(XLEN_64BIT) | MISA_UF | MISA_IF)

    def setup_csru(self, *, sirq=0, tirq=0, eirq=0):
        """Build ``self.m``: a Module holding a fresh CSRU and a sync domain.

        The CSRU's ports are mirrored as attributes of ``self`` so that
        simulation processes can drive and observe them.

        :param sirq: constant level of the software-interrupt input.
        :param tirq: constant level of the timer-interrupt input.
        :param eirq: constant level of the external-interrupt input.
            The hart-ID and IRQ inputs are never driven by the test
            processes, so they keep these reset values for the whole
            simulation.  This is what lets :meth:`check_ro_csr` check
            InputCSRs such as MIP.
        :returns: the list of mirrored port signals, used as waveform traces.
        """
        # CSRU is designed to look like and act like a single CSR.
        self.i_adr = Signal(12)
        self.i_dat = Signal(64)
        self.i_priv = Signal(2)
        self.i_ree = Signal()
        self.i_we = Signal()
        self.o_dat = Signal(64)
        self.o_valid = Signal()

        # CSRU must interact with the outside world as well, though.
        self.i_hart_id = Signal(4, reset=self.MHARTID_VALUE)
        self.i_sirq = Signal(reset=sirq)
        self.i_tirq = Signal(reset=tirq)
        self.i_eirq = Signal(reset=eirq)

        port_names = [
            'i_adr', 'i_dat', 'i_priv', 'i_ree', 'i_we', 'o_dat', 'o_valid',
            'i_hart_id', 'i_sirq', 'i_tirq', 'i_eirq',
        ]
        ports = [getattr(self, name) for name in port_names]

        self.dut = CSRU()

        self.m = Module()
        self.sync = ClockDomain()
        self.m.domains += self.sync
        self.m.submodules += self.dut

        self.m.d.comb += [
            self.dut.i_adr.eq(self.i_adr),
            self.dut.i_dat.eq(self.i_dat),
            self.dut.i_priv.eq(self.i_priv),
            self.dut.i_ree.eq(self.i_ree),
            self.dut.i_we.eq(self.i_we),
            self.o_dat.eq(self.dut.o_dat),
            self.o_valid.eq(self.dut.o_valid),

            self.dut.i_hart_id.eq(self.i_hart_id),
            self.dut.i_sirq.eq(self.i_sirq),
            self.dut.i_tirq.eq(self.i_tirq),
            self.dut.i_eirq.eq(self.i_eirq),
        ]

        return ports

    def _simulate(self, base_name, traces, process):
        """Run ``process`` against ``self.m``, dumping test_<base_name> waves.

        The waveform files are opened in ``with`` statements so they are
        always closed.  The Simulator context is nested inside them, so it
        has finished with both files before they are closed.
        """
        vcd_name = "test_{}.vcd".format(base_name)
        gtkw_name = "test_{}.gtkw".format(base_name)
        with open(vcd_name, "w") as vcd_file, \
                open(gtkw_name, "w") as gtkw_file:
            with Simulator(
                    self.m.elaborate(platform=None),
                    vcd_file=vcd_file,
                    gtkw_file=gtkw_file,
                    traces=traces,
            ) as sim:
                sim.add_process(process)
                sim.run()

    def make_check_rw_csr(self, reg, priv, wrdata, valid, rddata):
        """Return a process that exercises the read/write CSR at ``reg``.

        With ``i_we`` low, the process checks that ``o_valid`` is 1 and that
        ``o_dat`` is stable across a clock edge.  It then raises ``i_we``
        with ``i_dat = wrdata`` at privilege ``priv``.  Before the edge it
        checks that ``o_valid == valid`` and ``o_dat`` is unchanged.  After
        the rising edge, and again after the falling edge, it checks that
        ``o_valid == valid`` and ``o_dat == rddata``.
        """
        def process():
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 0)
            self.assertEqual((yield self.i_adr), 0)

            yield self.i_adr.eq(reg)
            yield self.i_priv.eq(C(priv))
            yield self.i_we.eq(0)
            yield self.i_ree.eq(0)
            yield self.i_dat.eq(wrdata)
            prev_dat = (yield self.o_dat)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 1)  # always 1 b/c i_we=0.
            self.assertEqual((yield self.o_dat), prev_dat)

            yield self.sync.clk.eq(1)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 1)  # always 1 b/c i_we=0.
            self.assertEqual((yield self.o_dat), prev_dat)

            yield self.sync.clk.eq(0)
            yield self.i_we.eq(1)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), valid)
            self.assertEqual((yield self.o_dat), prev_dat)

            yield self.sync.clk.eq(1)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), valid)
            self.assertEqual((yield self.o_dat), rddata)

            # Falling edge with the write still enabled: the written value
            # and validity must persist.
            yield self.sync.clk.eq(0)
            yield self.i_we.eq(1)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), valid)
            self.assertEqual((yield self.o_dat), rddata)
        return process

    def make_check_ro_csr(self, reg, fixed_value):
        """Return a process checking a read-only CSR at ``reg``.

        The CSR must not answer at privilege 2.  At privilege 3 it must read
        ``fixed_value``, and it must keep reading ``fixed_value`` after a
        clocked write attempt.
        """
        def process():
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 0)
            self.assertEqual((yield self.i_adr), 0)

            yield self.i_adr.eq(reg)
            yield self.i_priv.eq(C(2))
            yield self.i_we.eq(0)
            yield self.i_ree.eq(0)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 0)

            yield self.i_priv.eq(C(3))
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 1)
            self.assertEqual((yield self.o_dat), fixed_value)

            yield self.i_we.eq(1)
            yield self.i_dat.eq(0xDEADBEEFFEEDFACE)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 1)
            self.assertEqual((yield self.o_dat), fixed_value)

            yield self.sync.clk.eq(1)
            yield Delay(10e-9)
            self.assertEqual((yield self.o_valid), 1)
            self.assertEqual((yield self.o_dat), fixed_value)
        return process

    def check_ro_csr(self, base_name, reg, value, *, sirq=0, tirq=0, eirq=0):
        """Simulate :meth:`make_check_ro_csr` on a fresh CSRU.

        This can also be used to check InputCSRs: ``sirq``/``tirq``/``eirq``
        hold the corresponding IRQ inputs constant for the whole run.
        """
        traces = self.setup_csru(sirq=sirq, tirq=tirq, eirq=eirq)
        self._simulate(base_name, traces,
                       self.make_check_ro_csr(reg, value))

    def check_rw_csr(self, base_name, reg, priv, wrdata, valid, rddata):
        """Simulate :meth:`make_check_rw_csr` on a fresh CSRU."""
        traces = self.setup_csru()
        self._simulate(base_name, traces, self.make_check_rw_csr(
            reg, priv, wrdata, valid, rddata
        ))

    def test_mvendorid(self):
        self.check_ro_csr("mvendorid", CSR_MVENDORID, 0)

    def test_marchid(self):
        self.check_ro_csr("marchid", CSR_MARCHID, 0)

    def test_mimpid(self):
        self.check_ro_csr("mimpid", CSR_MIMPID, self.MIMPID_VALUE)

    def test_mhartid(self):
        self.check_ro_csr("mhartid", CSR_MHARTID, self.MHARTID_VALUE)

    def test_misa(self):
        self.check_ro_csr("misa", CSR_MISA, self.MISA_VALUE)

    def test_mip(self):
        # Each IRQ input is held constant for one simulation; the expected
        # MIP values put sirq at 0x008, tirq at 0x080 and eirq at 0x800.
        # (Previously this method used bare `yield`s, which made it a
        # generator that unittest never actually ran.)
        cases = (
            ("mip000", dict(sirq=0, tirq=0, eirq=0), 0x000),
            ("mip001", dict(sirq=1, tirq=0, eirq=0), 0x008),
            ("mip010", dict(sirq=0, tirq=1, eirq=0), 0x080),
            ("mip100", dict(sirq=0, tirq=0, eirq=1), 0x800),
        )
        for base_name, irqs, expected in cases:
            with self.subTest(base_name):
                self.check_ro_csr(base_name, CSR_MIP, expected, **irqs)

    def test_mcounteren(self):
        self.check_rw_csr("mcounteren", CSR_MCOUNTEREN, 3, 5, 1, 5)

    def test_mscratch(self):
        self.check_rw_csr(
            "mscratch", CSR_MSCRATCH, 3,
            0xDEADBEEFFEEDFACE, 1, 0xDEADBEEFFEEDFACE
        )

    def test_mepc(self):
        self.check_rw_csr(
            "mepc", CSR_MEPC, 3,
            0xDEADBEEFFEEDFACE, 1, 0xDEADBEEFFEEDFACE
        )

    def test_mtval(self):
        self.check_rw_csr(
            "mtval", CSR_MTVAL, 3,
            0xDEADBEEFFEEDFACE, 1, 0xDEADBEEFFEEDFACE
        )

    def test_mcause(self):
        self.check_rw_csr(
            "mcause", CSR_MCAUSE, 3, 0x800000000000000F, 1, 0x800000000000000F
        )
        self.check_rw_csr(
            "mcause_wlrl", CSR_MCAUSE, 3, 0x800000000000001F, 0, 0
        )

    def test_mstatus(self):
        self.check_rw_csr(
            "mstatus", CSR_MSTATUS, 3,
            0xFFFFFFFFFFFFFFFF, 1, 0x7FFFFFF0FF821ECC
        )
        self.check_rw_csr(
            "mstatus_wlrl", CSR_MSTATUS, 3,
            0xAAAAAAAAAAAAAAAA, 0, 0
        )

    def test_mie(self):
        self.check_rw_csr(
            "mie", CSR_MIE, 3,
            0xFFFFFFFFFFFFFFFF, 1, 0x0000000000000888
        )

    def test_mtvec(self):
        self.check_rw_csr(
            "mtvec", CSR_MTVEC, 3,
            0xFFFFFFFFFFFFFFFF, 1, 0xFFFFFFFFFFFFFFFC
        )

    def test_mhpmevent(self):
        # mhpmevent3..31 and mhpmcounter3..31 must read as zero and ignore
        # writes.
        for i in range(3, 32):
            with self.subTest(i=i):
                self.check_ro_csr(
                    "mhpmevent{}".format(i), CSR_MHPMEVENT0 + i, 0)
                self.check_ro_csr(
                    "mhpmcounter{}".format(i), CSR_MHPMCOUNTER0 + i, 0)

    def rd(self, priv, reg):
        """Perform one read cycle (i_ree=1, i_we=0) of ``reg`` at ``priv``.

        Drives one full clock (rise, then fall) and returns ``o_dat`` as
        sampled after the falling edge.  Use with ``yield from``.
        """
        yield self.i_adr.eq(reg)
        yield self.i_dat.eq(0)
        yield self.i_priv.eq(priv)
        yield self.i_we.eq(0)
        yield self.i_ree.eq(1)
        yield self.sync.clk.eq(1)
        yield Delay(10e-9)
        yield self.sync.clk.eq(0)
        yield Delay(10e-9)
        return (yield self.o_dat)

    def wr(self, priv, reg, datums):
        """Perform one write cycle (i_we=1, i_ree=0) of ``datums`` to ``reg``.

        Drives one full clock (rise, then fall).  Use with ``yield from``.
        """
        yield self.i_adr.eq(reg)
        yield self.i_dat.eq(datums)
        yield self.i_priv.eq(priv)
        yield self.i_we.eq(1)
        yield self.i_ree.eq(0)
        yield self.sync.clk.eq(1)
        yield Delay(10e-9)
        yield self.sync.clk.eq(0)
        yield Delay(10e-9)

    def test_mcycle_cycle(self):
        traces = self.setup_csru()

        def process():
            x0 = yield from self.rd(PRIV_M, CSR_MCYCLE)
            x1 = yield from self.rd(PRIV_M, CSR_MCYCLE)
            self.assertEqual((x1 - x0), 1)

            # CYCLE is not visible to U-mode until MCOUNTEREN.CY is set.
            yield from self.rd(PRIV_U, CSR_CYCLE)
            self.assertEqual((yield self.o_valid), 0)

            yield from self.wr(PRIV_M, CSR_MCOUNTEREN, MCOUNTERENF_CY)
            x0 = yield from self.rd(PRIV_U, CSR_CYCLE)
            x1 = yield from self.rd(PRIV_U, CSR_CYCLE)
            self.assertEqual((yield self.o_valid), 1)
            self.assertEqual((x1 - x0), 1)

            # CYCLE is a read-only mirror; M-mode code needs to write
            # to MCYCLE to reset the value of the register.
            yield from self.wr(PRIV_M, CSR_CYCLE, 0)
            x2 = yield from self.rd(PRIV_U, CSR_CYCLE)
            self.assertTrue(x2 > x1)

        self._simulate("mcycle_cycle", traces, process)

    def test_mstatus_formally(self):
        self.assertFormal(MStatusFormal(), mode="bmc", depth=100)
        self.assertFormal(MStatusFormal(), mode="prove", depth=100)