From 907081d255625150c6f5c88ae4a5f74142fd9f42 Mon Sep 17 00:00:00 2001
From: Alex Forencich <alex@alexforencich.com>
Date: Sun, 28 Nov 2021 23:09:10 -0800
Subject: [PATCH] Add support to demux for routing by tdest

---
 rtl/axis_demux.v                 | 33 +++++++++++++++++++++++++++++---
 rtl/axis_demux_wrap.py           |  7 +++++--
 tb/axis_demux/Makefile           |  3 +++
 tb/axis_demux/test_axis_demux.py | 16 ++++++++++++----
 4 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/rtl/axis_demux.v b/rtl/axis_demux.v
index c49fe425..7ddff4b3 100644
--- a/rtl/axis_demux.v
+++ b/rtl/axis_demux.v
@@ -54,7 +54,9 @@ module axis_demux #
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
-    parameter USER_WIDTH = 1
+    parameter USER_WIDTH = 1,
+    // route via tdest
+    parameter TDEST_ROUTE = 0
 )
 (
     input  wire                             clk,
@@ -94,6 +96,21 @@ module axis_demux #
 
 parameter CL_M_COUNT = $clog2(M_COUNT);
 
+// check configuration
+initial begin
+    if (TDEST_ROUTE) begin
+        if (!DEST_ENABLE) begin
+            $error("Error: TDEST_ROUTE set requires DEST_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (S_DEST_WIDTH < CL_M_COUNT) begin
+            $error("Error: S_DEST_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+end
+
 reg [CL_M_COUNT-1:0] select_reg = {CL_M_COUNT{1'b0}}, select_ctl, select_next;
 reg drop_reg = 1'b0, drop_ctl, drop_next;
 reg frame_reg = 1'b0, frame_ctl, frame_next;
@@ -133,8 +150,18 @@ always @* begin
 
     if (!frame_reg && s_axis_tvalid && s_axis_tready) begin
         // start of frame, grab select value
-        select_ctl = select;
-        drop_ctl = drop || select >= M_COUNT;
+        if (TDEST_ROUTE) begin
+            if (M_COUNT > 1) begin
+                select_ctl = s_axis_tdest[S_DEST_WIDTH-1:S_DEST_WIDTH-CL_M_COUNT];
+                drop_ctl = s_axis_tdest[S_DEST_WIDTH-1:S_DEST_WIDTH-CL_M_COUNT] >= M_COUNT;
+            end else begin
+                select_ctl = 0;
+                drop_ctl = 1'b0;
+            end
+        end else begin
+            select_ctl = select;
+            drop_ctl = drop || select >= M_COUNT;
+        end
         frame_ctl = 1'b1;
         if (!(s_axis_tready && s_axis_tvalid && s_axis_tlast)) begin
             select_next = select_ctl;
diff --git a/rtl/axis_demux_wrap.py b/rtl/axis_demux_wrap.py
index 682e31b0..344c8bcb 100755
--- a/rtl/axis_demux_wrap.py
+++ b/rtl/axis_demux_wrap.py
@@ -89,7 +89,9 @@ module {{name}} #
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
-    parameter USER_WIDTH = 1
+    parameter USER_WIDTH = 1,
+    // route via tdest
+    parameter TDEST_ROUTE = 0
 )
 (
     input  wire                     clk,
@@ -139,7 +141,8 @@ axis_demux #(
     .S_DEST_WIDTH(S_DEST_WIDTH),
     .M_DEST_WIDTH(M_DEST_WIDTH),
     .USER_ENABLE(USER_ENABLE),
-    .USER_WIDTH(USER_WIDTH)
+    .USER_WIDTH(USER_WIDTH),
+    .TDEST_ROUTE(TDEST_ROUTE)
 )
 axis_demux_inst (
     .clk(clk),
diff --git a/tb/axis_demux/Makefile b/tb/axis_demux/Makefile
index a7e2a1dd..2b081297 100644
--- a/tb/axis_demux/Makefile
+++ b/tb/axis_demux/Makefile
@@ -46,6 +46,7 @@ export PARAM_M_DEST_WIDTH ?= 8
 export PARAM_S_DEST_WIDTH ?= $(shell python -c "print($(PARAM_M_DEST_WIDTH) + ($(PARAM_PORTS)-1).bit_length())")
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
+export PARAM_TDEST_ROUTE ?= 1
 
 ifeq ($(SIM), icarus)
 	PLUSARGS += -fst
@@ -60,6 +61,7 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).M_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).TDEST_ROUTE=$(PARAM_TDEST_ROUTE)
 
 	ifeq ($(WAVES), 1)
 		VERILOG_SOURCES += iverilog_dump.v
@@ -78,6 +80,7 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GM_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -GTDEST_ROUTE=$(PARAM_TDEST_ROUTE)
 
 	ifeq ($(WAVES), 1)
 		COMPILE_ARGS += --trace-fst
diff --git a/tb/axis_demux/test_axis_demux.py b/tb/axis_demux/test_axis_demux.py
index 1231bf4e..ee997976 100644
--- a/tb/axis_demux/test_axis_demux.py
+++ b/tb/axis_demux/test_axis_demux.py
@@ -82,7 +82,13 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source.bus.tid)
+    id_width = len(tb.source.bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    dest_width = len(tb.sink[0].bus.tid)
+    dest_count = 2**dest_width
+    dest_mask = dest_count-1
 
     cur_id = 1
 
@@ -100,7 +106,7 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
         test_frame.tid = cur_id
-        test_frame.tdest = cur_id
+        test_frame.tdest = cur_id | (port << dest_width)
 
         test_frames.append(test_frame)
         await tb.source.send(test_frame)
@@ -112,7 +118,7 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
         assert rx_frame.tdata == test_frame.tdata
         assert rx_frame.tid == test_frame.tid
-        assert rx_frame.tdest == test_frame.tdest
+        assert rx_frame.tdest == (test_frame.tdest & dest_mask)
         assert not rx_frame.tuser
 
     assert tb.sink[port].empty()
@@ -154,9 +160,10 @@ tests_dir = os.path.dirname(__file__)
 rtl_dir = os.path.abspath(os.path.join(tests_dir, '..', '..', 'rtl'))
 
 
+@pytest.mark.parametrize("tdest_route", [0, 1])
 @pytest.mark.parametrize("data_width", [8, 16, 32])
 @pytest.mark.parametrize("ports", [4])
-def test_axis_demux(request, ports, data_width):
+def test_axis_demux(request, ports, data_width, tdest_route):
     dut = "axis_demux"
     wrapper = f"{dut}_wrap_{ports}"
     module = os.path.splitext(os.path.basename(__file__))[0]
@@ -187,6 +194,7 @@ def test_axis_demux(request, ports, data_width):
     parameters['S_DEST_WIDTH'] = parameters['M_DEST_WIDTH'] + (ports-1).bit_length()
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
+    parameters['TDEST_ROUTE'] = tdest_route
 
     extra_env = {f'PARAM_{k}': str(v) for k, v in parameters.items()}
 
-- 
GitLab