Kestrel-3

Artifact [217f286244]
Login

Artifact 217f2862441726f55a61ea25ed1a21c11f6b1c03400232b26ab88d4e41a88c73:


from nmigen import Elaboratable, Signal, Module, Const
from nmigen.hdl.ast import Assert, Assume, Past, ResetSignal
from nmigen.test.tools import FHDLTestCase

from common_csr import (
    CSRSelect, ConstantCSR, InputCSR, MemCSR, UpCounterCSR
)

class CSRSelectFormal(Elaboratable):
    """Formal harness for ``CSRSelect`` configured for CSR address 0x123.

    The select output ``o_valid`` must be low whenever the request is not
    valid, the address differs from 0x123, or the privilege level is zero,
    and must be high when all three conditions are satisfied.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.i_valid = Signal()
        self.o_valid = Signal()

    def elaborate(self, platform):
        m = Module()

        # Any non-zero privilege level counts as acceptable here.
        priv_ok = Signal()
        m.d.comb += priv_ok.eq(self.i_priv[0] | self.i_priv[1])

        sel = CSRSelect(0x123)
        m.submodules.dut = sel
        m.d.comb += [
            sel.i_adr.eq(self.i_adr),
            sel.i_priv.eq(self.i_priv),
            sel.i_valid.eq(self.i_valid),
            self.o_valid.eq(sel.o_valid),
        ]

        adr_hit = self.i_adr == Const(0x123, 12)
        no_priv = self.i_priv == Const(0, 2)

        # Each veto condition on its own must keep the select deasserted.
        for veto in (~self.i_valid, ~adr_hit, no_priv):
            with m.If(veto):
                m.d.comb += Assert(~self.o_valid)

        # With every condition met, the select must fire.
        with m.If(adr_hit & priv_ok & self.i_valid):
            m.d.comb += Assert(self.o_valid)

        return m


class ConstantCSRFormal(Elaboratable):
    """Formal harness for a ``ConstantCSR`` at address 0x123 holding
    0xDEADBEEFFEEDFACE.

    ``o_valid`` must follow "address matches and privilege is non-zero",
    and ``o_dat`` must equal the constant unconditionally.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.o_dat = Signal(64)
        self.o_valid = Signal()

    def elaborate(self, platform):
        m = Module()

        # Any non-zero privilege level counts as acceptable here.
        priv_ok = Signal()
        m.d.comb += [priv_ok.eq(self.i_priv[0] | self.i_priv[1])]

        csr = ConstantCSR(0x123, 0xDEADBEEFFEEDFACE)
        m.submodules.dut = csr
        m.d.comb += [
            csr.i_adr.eq(self.i_adr),
            csr.i_priv.eq(self.i_priv),
            self.o_dat.eq(csr.o_dat),
            self.o_valid.eq(csr.o_valid),
        ]

        adr_hit = self.i_adr == Const(0x123, 12)

        # A wrong address or a zero privilege level must deselect the CSR.
        for veto in (~adr_hit, self.i_priv == Const(0, 2)):
            with m.If(veto):
                m.d.comb += Assert(~self.o_valid)

        with m.If(adr_hit & priv_ok):
            m.d.comb += Assert(self.o_valid)

        # The read value does not depend on select or privilege.
        m.d.comb += Assert(self.o_dat == Const(0xDEADBEEFFEEDFACE, 64))

        return m


class InputCSRFormal(Elaboratable):
    """Formal harness for a 52-bit ``InputCSR`` at address 0x123.

    ``o_valid`` must follow "address matches and privilege is non-zero",
    and ``o_dat`` must always mirror ``i_signal``.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.o_dat = Signal(64)
        self.o_valid = Signal()
        self.i_signal = Signal(52)

    def elaborate(self, platform):
        m = Module()

        # Any non-zero privilege level counts as acceptable here.
        priv_ok = Signal()
        m.d.comb += priv_ok.eq(self.i_priv[0] | self.i_priv[1])

        csr = InputCSR(0x123, signal_width=52)
        m.submodules.dut = csr
        m.d.comb += [
            csr.i_adr.eq(self.i_adr),
            csr.i_priv.eq(self.i_priv),
            csr.i_signal.eq(self.i_signal),
            self.o_dat.eq(csr.o_dat),
            self.o_valid.eq(csr.o_valid),
        ]

        adr_hit = self.i_adr == Const(0x123, 12)

        # A wrong address or a zero privilege level must deselect the CSR.
        for veto in (~adr_hit, self.i_priv == Const(0, 2)):
            with m.If(veto):
                m.d.comb += Assert(~self.o_valid)

        with m.If(adr_hit & priv_ok):
            m.d.comb += Assert(self.o_valid)

        # The 52-bit input is zero-extended for the comparison, so this
        # also requires o_dat[52:] to be zero.
        m.d.comb += Assert(self.o_dat == self.i_signal)

        return m


class MemCSRFormal(Elaboratable):
    """Formal harness for a ``MemCSR`` at address 0x123 whose storage covers
    bits ``range(2, 64)``.

    Checks address/privilege decoding, that bits 0 and 1 read as zero, and
    that a write with the CSR selected is visible on ``o_dat[2:]`` in the
    following cycle.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_dat = Signal(64)
        self.i_priv = Signal(2)
        self.i_we = Signal()
        self.o_dat = Signal(64)
        self.o_valid = Signal()

    def elaborate(self, platform):
        m = Module()

        i_priv_valid = Signal()
        m.d.comb += i_priv_valid.eq(self.i_priv[0] | self.i_priv[1])

        # Low only in the first cycle; guards every use of Past().  It must
        # be reset_less: as an ordinary sync register, the Assume(rst) below
        # would keep resetting it to 0, which would keep rst asserted
        # forever and make every z_past_valid-guarded assertion vacuous.
        z_past_valid = Signal(reset=0, reset_less=True)
        m.d.sync += z_past_valid.eq(1)

        # Start every trace in reset.  The Assume has to be added to a
        # domain; a bare ``Assume(rst)`` expression is simply discarded.
        rst = ResetSignal()
        with m.If(~z_past_valid):
            m.d.comb += Assume(rst)

        dut = MemCSR(0x123, range(2, 64))
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_dat.eq(self.i_dat),
            dut.i_priv.eq(self.i_priv),
            dut.i_we.eq(self.i_we),
            self.o_dat.eq(dut.o_dat),
            self.o_valid.eq(dut.o_valid),
        ]

        with m.If(self.i_adr != Const(0x123, 12)):
            m.d.comb += Assert(~self.o_valid)

        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)

        with m.If((self.i_adr == Const(0x123, 12)) & i_priv_valid):
            m.d.comb += Assert(self.o_valid)

        # Bits 0 and 1 lie outside the storage range and must read as zero.
        # The slice is [0:2]; [0:1] would cover bit 0 only.
        m.d.comb += Assert(self.o_dat[0:2] == Const(0, 2))

        # I wish I could just reference Past(self.i_dat) in the Assert
        # below, but alas, nmigen will not accept that construct.  This
        # works around this limitation. -saf2
        past_i_dat = Signal(64, reset_less=True)
        m.d.comb += past_i_dat.eq(Past(self.i_dat))
        with m.If(
            z_past_valid &
            ~Past(rst) &
            Past(self.i_we) &
            Past(self.o_valid)
        ):
            m.d.sync += Assert(self.o_dat[2:] == past_i_dat[2:])

        return m


class UpCounterCSRFormal(Elaboratable):
    """Formal harness for a 16-bit ``UpCounterCSR`` at address 0x123.

    Checks address/privilege decoding and the low 16 bits of ``o_dat``:
    zero after reset, loaded from ``i_dat[0:16]`` by a selected write, and
    otherwise incremented (mod 2**16) on ``i_tick`` or held without it.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_dat = Signal(64)
        self.i_priv = Signal(2)
        self.i_we = Signal()
        self.o_dat = Signal(64)
        self.o_valid = Signal()
        self.i_tick = Signal()

    def elaborate(self, platform):
        m = Module()

        i_priv_valid = Signal()
        m.d.comb += i_priv_valid.eq(self.i_priv[0] | self.i_priv[1])

        # Low only in the first cycle; guards every use of Past().  It must
        # be reset_less: as an ordinary sync register, the Assume(rst) below
        # would keep resetting it to 0, which would keep rst asserted
        # forever and make every z_past_valid-guarded assertion vacuous.
        z_past_valid = Signal(reset=0, reset_less=True)
        m.d.sync += z_past_valid.eq(1)

        # Start every trace in reset.  The Assume has to be added to a
        # domain; a bare ``Assume(rst)`` expression is simply discarded.
        rst = ResetSignal()
        with m.If(~z_past_valid):
            m.d.comb += Assume(rst)

        dut = UpCounterCSR(0x123, counter_width=16)
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_dat.eq(self.i_dat),
            dut.i_priv.eq(self.i_priv),
            dut.i_we.eq(self.i_we),
            dut.i_tick.eq(self.i_tick),
            self.o_dat.eq(dut.o_dat),
            self.o_valid.eq(dut.o_valid),
        ]

        # Only the 16 counter bits are checked below.
        new_o_dat = Signal(16, reset_less=True)
        m.d.comb += new_o_dat.eq(self.o_dat[0:16])

        with m.If(self.i_adr != Const(0x123, 12)):
            m.d.comb += Assert(~self.o_valid)

        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)

        with m.If((self.i_adr == Const(0x123, 12)) & i_priv_valid):
            m.d.comb += Assert(self.o_valid)

        # Counter reads zero after a reset cycle.
        with m.If(z_past_valid & Past(rst)):
            m.d.sync += Assert(new_o_dat == Const(0, 16))

        old_i_dat = Signal(16, reset_less=True)
        m.d.comb += old_i_dat.eq(Past(self.i_dat)[0:16])

        # A write with the CSR selected loads the written value.
        with m.If(
            z_past_valid &
            ~Past(rst) &
            Past(self.i_we) &
            Past(self.o_valid)
        ):
            m.d.sync += Assert(new_o_dat == old_i_dat)

        old_o_dat = Signal(16, reset_less=True)
        m.d.comb += old_o_dat.eq(Past(new_o_dat))

        # A write strobe while deselected, with no tick, leaves the value
        # unchanged.
        # NOTE(review): the "write strobe while deselected, with tick" case
        # is not asserted; presumably the counter increments there as in
        # the no-write case below -- confirm against UpCounterCSR and add it.
        with m.If(
            z_past_valid &
            ~Past(rst) &
            Past(self.i_we) &
            ~Past(self.o_valid) &
            ~Past(self.i_tick)
        ):
            m.d.sync += Assert(new_o_dat == old_o_dat)

        # No write, tick: increment with 16-bit wrap-around.
        with m.If(
            z_past_valid &
            ~Past(rst) &
            ~Past(self.i_we) &
            Past(self.i_tick)
        ):
            m.d.sync += Assert(new_o_dat == ((old_o_dat + 1) & 0xFFFF))

        # No write, no tick: hold.
        with m.If(
            z_past_valid &
            ~Past(rst) &
            ~Past(self.i_we) &
            ~Past(self.i_tick)
        ):
            m.d.sync += Assert(new_o_dat == old_o_dat)

        return m


class CommonCSRTest(FHDLTestCase):
    """Run every CSR harness through bounded model checking, then through
    an unbounded proof, each at depth 100."""

    # Formal modes, run in this order for every harness.
    _MODES = ("bmc", "prove")
    # Depth used for every formal run.
    _DEPTH = 100

    # NOTE(review): assertFormal is invoked directly inside each test_*
    # method rather than through one shared helper.  nmigen's FHDLTestCase
    # appears to name its working directory after the calling function, so
    # a helper would make all tests share a directory -- confirm against
    # nmigen.test.tools before consolidating further.

    def test_CSRSelect(self):
        for mode in self._MODES:
            self.assertFormal(CSRSelectFormal(), mode=mode, depth=self._DEPTH)

    def test_ConstantCSR(self):
        for mode in self._MODES:
            self.assertFormal(ConstantCSRFormal(), mode=mode, depth=self._DEPTH)

    def test_InputCSR(self):
        for mode in self._MODES:
            self.assertFormal(InputCSRFormal(), mode=mode, depth=self._DEPTH)

    def test_MemCSR(self):
        for mode in self._MODES:
            self.assertFormal(MemCSRFormal(), mode=mode, depth=self._DEPTH)

    def test_UpCounterCSR(self):
        for mode in self._MODES:
            self.assertFormal(UpCounterCSRFormal(), mode=mode, depth=self._DEPTH)