from nmigen import Elaboratable, Signal, Module, Const
from nmigen.hdl.ast import Assert, Assume, Past, ResetSignal
from nmigen.test.tools import FHDLTestCase
from common_csr import (
CSRSelect, ConstantCSR, InputCSR, MemCSR, UpCounterCSR
)
class CSRSelectFormal(Elaboratable):
    """Formal harness for ``CSRSelect`` instantiated at CSR address 0x123.

    The harness mirrors the DUT's ports and asserts that ``o_valid`` is high
    exactly when ``i_valid`` is set, ``i_adr`` equals 0x123, and ``i_priv``
    is non-zero.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.i_valid = Signal()
        self.o_valid = Signal()

    def elaborate(self, platform):
        m = Module()

        # The assertions below expect any non-zero privilege level to be
        # accepted and privilege 0 to be rejected.
        i_priv_valid = Signal()
        m.d.comb += i_priv_valid.eq(self.i_priv[0] | self.i_priv[1])

        dut = CSRSelect(0x123)
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_priv.eq(self.i_priv),
            dut.i_valid.eq(self.i_valid),
            self.o_valid.eq(dut.o_valid),
        ]

        # Each of these conditions alone must deselect the CSR.
        with m.If(~self.i_valid):
            m.d.comb += Assert(~self.o_valid)
        with m.If(self.i_adr != Const(0x123, 12)):
            m.d.comb += Assert(~self.o_valid)
        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)

        # When all selection conditions hold, the CSR must be selected.
        with m.If(
            (self.i_adr == Const(0x123, 12)) &
            i_priv_valid & self.i_valid
        ):
            m.d.comb += Assert(self.o_valid)

        return m
class ConstantCSRFormal(Elaboratable):
    """Formal harness for a ``ConstantCSR`` at address 0x123.

    The CSR must be selected only for address 0x123 with a non-zero
    privilege level. While it is selected, its read data must be the
    constant 0xDEADBEEFFEEDFACE it was built with.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.o_dat = Signal(64)
        self.o_valid = Signal()

    def elaborate(self, platform):
        csr_adr = 0x123
        csr_value = 0xDEADBEEFFEEDFACE

        m = Module()

        dut = ConstantCSR(csr_adr, csr_value)
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_priv.eq(self.i_priv),
            self.o_dat.eq(dut.o_dat),
            self.o_valid.eq(dut.o_valid),
        ]

        # High for either non-zero privilege encoding.
        priv_nonzero = Signal()
        m.d.comb += [priv_nonzero.eq(self.i_priv[0] | self.i_priv[1])]

        adr_hit = self.i_adr == Const(csr_adr, 12)

        # A wrong address or zero privilege must deselect the CSR.
        with m.If(~adr_hit):
            m.d.comb += Assert(~self.o_valid)
        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)

        # Otherwise the CSR is selected and presents its constant.
        with m.If(adr_hit & priv_nonzero):
            m.d.comb += [
                Assert(self.o_valid),
                Assert(self.o_dat == Const(csr_value, 64)),
            ]

        return m
class InputCSRFormal(Elaboratable):
    """Formal harness for a 52-bit ``InputCSR`` at address 0x123.

    The CSR must be selected only for address 0x123 with a non-zero
    privilege level. While it is selected, its 64-bit read data must equal
    the 52-bit ``i_signal``, so the upper bits must read as zero.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_priv = Signal(2)
        self.o_dat = Signal(64)
        self.o_valid = Signal()
        self.i_signal = Signal(52)

    def elaborate(self, platform):
        csr_adr = 0x123

        m = Module()

        dut = InputCSR(csr_adr, signal_width=52)
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_priv.eq(self.i_priv),
            dut.i_signal.eq(self.i_signal),
            self.o_dat.eq(dut.o_dat),
            self.o_valid.eq(dut.o_valid),
        ]

        # High for either non-zero privilege encoding.
        priv_nonzero = Signal()
        m.d.comb += priv_nonzero.eq(self.i_priv[0] | self.i_priv[1])

        adr_hit = self.i_adr == Const(csr_adr, 12)

        # A wrong address or zero privilege must deselect the CSR.
        with m.If(~adr_hit):
            m.d.comb += Assert(~self.o_valid)
        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)

        # Otherwise the CSR is selected and passes i_signal through.
        with m.If(adr_hit & priv_nonzero):
            m.d.comb += [
                Assert(self.o_valid),
                Assert(self.o_dat == self.i_signal),
            ]

        return m
class MemCSRFormal(Elaboratable):
    """Formal harness for a ``MemCSR`` at address 0x123.

    The DUT is built with ``range(2, 64)`` as its bit range. The harness
    checks three things:

    - The CSR is selected only for address 0x123 with a non-zero privilege
      level.
    - While the CSR is selected, bits 0-1 of its read data are zero.
    - After a selected write outside reset, bits 2-63 hold the written data.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_dat = Signal(64)
        self.i_priv = Signal(2)
        self.i_we = Signal()
        self.o_dat = Signal(64)
        self.o_valid = Signal()

    def elaborate(self, platform):
        m = Module()

        i_priv_valid = Signal()
        m.d.comb += i_priv_valid.eq(self.i_priv[0] | self.i_priv[1])

        # Low only in the first cycle. Every Past()-based check below is
        # gated on it.
        z_past_valid = Signal(reset=0)
        m.d.sync += z_past_valid.eq(1)

        # Start the trace in reset. An Assume must be added to a domain to
        # take effect; the bare `Assume(rst)` expression was discarded.
        rst = ResetSignal()
        with m.If(~z_past_valid):
            m.d.comb += Assume(rst)

        dut = MemCSR(0x123, range(2, 64))
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_dat.eq(self.i_dat),
            dut.i_priv.eq(self.i_priv),
            dut.i_we.eq(self.i_we),
            self.o_dat.eq(dut.o_dat),
            self.o_valid.eq(dut.o_valid),
        ]

        with m.If(self.i_adr != Const(0x123, 12)):
            m.d.comb += Assert(~self.o_valid)
        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)
        with m.If((self.i_adr == Const(0x123, 12)) & i_priv_valid):
            m.d.comb += Assert(self.o_valid)
            # Bits 0-1 lie outside range(2, 64) and must read as zero. The
            # slice [0:2] covers both bits; [0:1] would only cover bit 0.
            m.d.comb += Assert(self.o_dat[0:2] == Const(0, 2))

        # I wish I could just reference Past(self.i_dat) in the Assert
        # below, but alas, nmigen will not accept that construct. This
        # works around this limitation. -saf2
        past_i_dat = Signal(64, reset_less=True)
        m.d.comb += past_i_dat.eq(Past(self.i_dat))
        with m.If(
            z_past_valid &
            ~Past(rst) &
            Past(self.i_we) &
            Past(self.o_valid)
        ):
            m.d.sync += Assert(self.o_dat[2:] == past_i_dat[2:])

        return m
class UpCounterCSRFormal(Elaboratable):
    """Formal harness for a 16-bit ``UpCounterCSR`` at address 0x123.

    Besides the usual address/privilege selection checks, the low 16 bits
    of ``o_dat`` are checked cycle by cycle:

    - they are zero after reset;
    - they take the low 16 bits of ``i_dat`` after a selected write;
    - they increment by one (mod 2**16) after a tick with no write;
    - they hold their value when nothing happens.
    """

    def __init__(self):
        self.i_adr = Signal(12)
        self.i_dat = Signal(64)
        self.i_priv = Signal(2)
        self.i_we = Signal()
        self.o_dat = Signal(64)
        self.o_valid = Signal()
        self.i_tick = Signal()

    def elaborate(self, platform):
        m = Module()

        i_priv_valid = Signal()
        m.d.comb += i_priv_valid.eq(self.i_priv[0] | self.i_priv[1])

        # Low only in the first cycle. Every Past()-based check below is
        # gated on it.
        z_past_valid = Signal(reset=0)
        m.d.sync += z_past_valid.eq(1)

        # Start the trace in reset. An Assume must be added to a domain to
        # take effect; the bare `Assume(rst)` expression was discarded.
        rst = ResetSignal()
        with m.If(~z_past_valid):
            m.d.comb += Assume(rst)

        dut = UpCounterCSR(0x123, counter_width=16)
        m.submodules.dut = dut
        m.d.comb += [
            dut.i_adr.eq(self.i_adr),
            dut.i_dat.eq(self.i_dat),
            dut.i_priv.eq(self.i_priv),
            dut.i_we.eq(self.i_we),
            dut.i_tick.eq(self.i_tick),
            self.o_dat.eq(dut.o_dat),
            self.o_valid.eq(dut.o_valid),
        ]

        # The low 16 bits of the read data (the counter width configured above).
        new_o_dat = Signal(16, reset_less=True)
        m.d.comb += new_o_dat.eq(self.o_dat[0:16])

        with m.If(self.i_adr != Const(0x123, 12)):
            m.d.comb += Assert(~self.o_valid)
        with m.If(self.i_priv == Const(0, 2)):
            m.d.comb += Assert(~self.o_valid)
        with m.If((self.i_adr == Const(0x123, 12)) & i_priv_valid):
            m.d.comb += Assert(self.o_valid)

        # The counter is cleared by reset.
        with m.If(z_past_valid & Past(rst)):
            m.d.sync += Assert(new_o_dat == Const(0, 16))

        # A selected write loads the counter from i_dat.
        old_i_dat = Signal(16, reset_less=True)
        m.d.comb += old_i_dat.eq(Past(self.i_dat)[0:16])
        with m.If(
            z_past_valid &
            ~Past(rst) &
            Past(self.i_we) &
            Past(self.o_valid)
        ):
            m.d.sync += Assert(new_o_dat == old_i_dat)

        # NOTE(review): the combination i_we & ~o_valid & i_tick (a write to
        # some other CSR while ticking) is not covered by any assertion
        # below. Confirm the intended DUT behaviour before adding one.

        # A write aimed elsewhere, with no tick, leaves the counter unchanged.
        old_o_dat = Signal(16, reset_less=True)
        m.d.comb += old_o_dat.eq(Past(new_o_dat))
        with m.If(
            z_past_valid &
            ~Past(rst) &
            Past(self.i_we) &
            ~Past(self.o_valid) &
            ~Past(self.i_tick)
        ):
            m.d.sync += Assert(new_o_dat == old_o_dat)

        # A tick with no write increments the counter, wrapping at 16 bits.
        with m.If(
            z_past_valid &
            ~Past(rst) &
            ~Past(self.i_we) &
            Past(self.i_tick)
        ):
            m.d.sync += Assert(new_o_dat == ((old_o_dat + 1) & 0xFFFF))

        # With neither a write nor a tick, the counter holds.
        with m.If(
            z_past_valid &
            ~Past(rst) &
            ~Past(self.i_we) &
            ~Past(self.i_tick)
        ):
            m.d.sync += Assert(new_o_dat == old_o_dat)

        return m
class CommonCSRTest(FHDLTestCase):
    """Check each CSR harness in "bmc" mode and then in "prove" mode.

    Both modes use depth 100, and each check gets a freshly constructed
    harness.

    NOTE(review): nmigen's ``assertFormal`` appears to name its working
    directory after the calling method. For that reason the calls stay
    inside each ``test_*`` method instead of going through a shared helper.
    """

    def test_CSRSelect(self):
        for mode in ("bmc", "prove"):
            self.assertFormal(CSRSelectFormal(), mode=mode, depth=100)

    def test_ConstantCSR(self):
        for mode in ("bmc", "prove"):
            self.assertFormal(ConstantCSRFormal(), mode=mode, depth=100)

    def test_InputCSR(self):
        for mode in ("bmc", "prove"):
            self.assertFormal(InputCSRFormal(), mode=mode, depth=100)

    def test_MemCSR(self):
        for mode in ("bmc", "prove"):
            self.assertFormal(MemCSRFormal(), mode=mode, depth=100)

    def test_UpCounterCSR(self):
        for mode in ("bmc", "prove"):
            self.assertFormal(UpCounterCSRFormal(), mode=mode, depth=100)