diff --git a/rtl/axis_arb_mux.v b/rtl/axis_arb_mux.v
index a9f5c10551a33019e60598edc0063429d7698e06..84ea0ff9e6fd42099f36868ad800ebd14ade7d98 100644
--- a/rtl/axis_arb_mux.v
+++ b/rtl/axis_arb_mux.v
@@ -57,6 +57,8 @@ module axis_arb_mux #
     parameter USER_WIDTH = 1,
     // Propagate tlast signal
     parameter LAST_ENABLE = 1,
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 0,
     // LSB priority selection
@@ -93,6 +95,21 @@ module axis_arb_mux #
 
 parameter CL_S_COUNT = $clog2(S_COUNT);
 
+// check configuration
+initial begin
+    if (UPDATE_TID) begin
+        if (!ID_ENABLE) begin
+            $error("Error: UPDATE_TID set requires ID_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (M_ID_WIDTH < CL_S_COUNT) begin
+            $error("Error: M_ID_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+end
+
 wire [S_COUNT-1:0] request;
 wire [S_COUNT-1:0] acknowledge;
 wire [S_COUNT-1:0] grant;
@@ -150,6 +167,9 @@ always @* begin
     m_axis_tvalid_int = current_s_tvalid && m_axis_tready_int_reg && grant_valid;
     m_axis_tlast_int  = current_s_tlast;
     m_axis_tid_int    = current_s_tid;
+    if (UPDATE_TID && S_COUNT > 1) begin
+        m_axis_tid_int[M_ID_WIDTH-1:M_ID_WIDTH-CL_S_COUNT] = grant_encoded;
+    end
     m_axis_tdest_int  = current_s_tdest;
     m_axis_tuser_int  = current_s_tuser;
 end
diff --git a/rtl/axis_arb_mux_wrap.py b/rtl/axis_arb_mux_wrap.py
index 7e02d207fbb6028bb745e4f8f5c27e73ac298bc0..091d36612846bd21fe8087c8f48b6b6e03a81e8e 100755
--- a/rtl/axis_arb_mux_wrap.py
+++ b/rtl/axis_arb_mux_wrap.py
@@ -92,6 +92,8 @@ module {{name}} #
     parameter USER_WIDTH = 1,
     // Propagate tlast signal
     parameter LAST_ENABLE = 1,
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 0,
     // LSB priority selection
@@ -140,6 +142,7 @@ axis_arb_mux #(
     .USER_ENABLE(USER_ENABLE),
     .USER_WIDTH(USER_WIDTH),
     .LAST_ENABLE(LAST_ENABLE),
+    .UPDATE_TID(UPDATE_TID),
     .ARB_TYPE_ROUND_ROBIN(ARB_TYPE_ROUND_ROBIN),
     .ARB_LSB_HIGH_PRIORITY(ARB_LSB_HIGH_PRIORITY)
 )
diff --git a/rtl/axis_ram_switch.v b/rtl/axis_ram_switch.v
index a88184161b6a4490f8a5a96aae9f75cb1b845ccb..4b75c089decd156088eb935bc7e3c69950e63011 100644
--- a/rtl/axis_ram_switch.v
+++ b/rtl/axis_ram_switch.v
@@ -96,6 +96,8 @@ module axis_ram_switch #
     // Interface connection control
     // M_COUNT concatenated fields of S_COUNT bits
     parameter M_CONNECT = {M_COUNT{{S_COUNT{1'b1}}}},
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 1,
     // LSB priority selection
@@ -188,6 +190,18 @@ initial begin
         $finish;
     end
 
+    if (UPDATE_TID) begin
+        if (!ID_ENABLE) begin
+            $error("Error: UPDATE_TID set requires ID_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (M_ID_WIDTH < CL_S_COUNT) begin
+            $error("Error: M_ID_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+
     if (M_BASE == 0) begin
         // M_BASE is zero, route with tdest as port index
         $display("Addressing configuration for axis_switch instance %m");
@@ -825,16 +839,30 @@ generate
         );
 
         // mux
-        wire [RAM_ADDR_WIDTH-1:0] cmd_addr_mux  = int_cmd_addr[grant_encoded*RAM_ADDR_WIDTH +: RAM_ADDR_WIDTH];
-        wire [ADDR_WIDTH-1:0]     cmd_len_mux   = int_cmd_len[grant_encoded*ADDR_WIDTH +: ADDR_WIDTH];
-        wire [CMD_ADDR_WIDTH-1:0] cmd_id_mux    = int_cmd_id[grant_encoded*CMD_ADDR_WIDTH +: CMD_ADDR_WIDTH];
-        wire [KEEP_WIDTH-1:0]     cmd_tkeep_mux = int_cmd_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
-        wire [M_ID_WIDTH-1:0]     cmd_tid_mux   = int_cmd_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH];
-        wire [M_DEST_WIDTH-1:0]   cmd_tdest_mux = int_cmd_tdest[grant_encoded*S_DEST_WIDTH +: S_DEST_WIDTH];
-        wire [USER_WIDTH-1:0]     cmd_tuser_mux = int_cmd_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
-        wire                      cmd_valid_mux = int_cmd_valid[grant_encoded*M_COUNT+n] && grant_valid;
+        reg  [RAM_ADDR_WIDTH-1:0] cmd_addr_mux;
+        reg  [ADDR_WIDTH-1:0]     cmd_len_mux;
+        reg  [CMD_ADDR_WIDTH-1:0] cmd_id_mux;
+        reg  [KEEP_WIDTH-1:0]     cmd_tkeep_mux;
+        reg  [M_ID_WIDTH-1:0]     cmd_tid_mux;
+        reg  [M_DEST_WIDTH-1:0]   cmd_tdest_mux;
+        reg  [USER_WIDTH-1:0]     cmd_tuser_mux;
+        reg                       cmd_valid_mux;
         wire                      cmd_ready_mux;
 
+        always @* begin
+            cmd_addr_mux  = int_cmd_addr[grant_encoded*RAM_ADDR_WIDTH +: RAM_ADDR_WIDTH];
+            cmd_len_mux   = int_cmd_len[grant_encoded*ADDR_WIDTH +: ADDR_WIDTH];
+            cmd_id_mux    = int_cmd_id[grant_encoded*CMD_ADDR_WIDTH +: CMD_ADDR_WIDTH];
+            cmd_tkeep_mux = int_cmd_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
+            cmd_tid_mux   = int_cmd_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH];
+            if (UPDATE_TID && S_COUNT > 1) begin
+                cmd_tid_mux[M_ID_WIDTH-1:M_ID_WIDTH-CL_S_COUNT] = grant_encoded;
+            end
+            cmd_tdest_mux = int_cmd_tdest[grant_encoded*S_DEST_WIDTH +: S_DEST_WIDTH];
+            cmd_tuser_mux = int_cmd_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
+            cmd_valid_mux = int_cmd_valid[grant_encoded*M_COUNT+n] && grant_valid;
+        end
+
         assign int_cmd_ready[n*S_COUNT +: S_COUNT] = (grant_valid && cmd_ready_mux) << grant_encoded;
 
         for (m = 0; m < S_COUNT; m = m + 1) begin
diff --git a/rtl/axis_ram_switch_wrap.py b/rtl/axis_ram_switch_wrap.py
index 4d8e25b29ac5c85c4b5328dbc791ab7e188f73da..45adfed5bf389e29f435c06185dd2f109953f5a1 100755
--- a/rtl/axis_ram_switch_wrap.py
+++ b/rtl/axis_ram_switch_wrap.py
@@ -132,6 +132,8 @@ module {{name}} #
     // Interface connection control
     parameter M{{'%02d'%p}}_CONNECT = {{m}}'b{% for p in range(m) %}1{% endfor %},
 {%- endfor %}
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 1,
     // LSB priority selection
@@ -212,6 +214,7 @@ axis_ram_switch #(
     .M_BASE({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_BASE){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_TOP({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_TOP){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_CONNECT({ {% for p in range(n-1,-1,-1) %}w_s(M{{'%02d'%p}}_CONNECT){% if not loop.last %}, {% endif %}{% endfor %} }),
+    .UPDATE_TID(UPDATE_TID),
     .ARB_TYPE_ROUND_ROBIN(ARB_TYPE_ROUND_ROBIN),
     .ARB_LSB_HIGH_PRIORITY(ARB_LSB_HIGH_PRIORITY),
     .RAM_PIPELINE(RAM_PIPELINE)
diff --git a/rtl/axis_switch.v b/rtl/axis_switch.v
index f65368f72eaf2b13f8558ea6bd97f28f39480ad4..7274d5478edc861ffe3c3e09214816fc3a37252c 100644
--- a/rtl/axis_switch.v
+++ b/rtl/axis_switch.v
@@ -71,6 +71,8 @@ module axis_switch #
     // Interface connection control
     // M_COUNT concatenated fields of S_COUNT bits
     parameter M_CONNECT = {M_COUNT{{S_COUNT{1'b1}}}},
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // Input interface register type
     // 0 to bypass, 1 for simple buffer, 2 for skid buffer
     parameter S_REG_TYPE = 0,
@@ -123,6 +125,18 @@ initial begin
         $finish;
     end
 
+    if (UPDATE_TID) begin
+        if (!ID_ENABLE) begin
+            $error("Error: UPDATE_TID set requires ID_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (M_ID_WIDTH < CL_S_COUNT) begin
+            $error("Error: M_ID_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+
     if (M_BASE == 0) begin
         // M_BASE is zero, route with tdest as port index
         $display("Addressing configuration for axis_switch instance %m");
@@ -320,14 +334,27 @@ generate
         );
 
         // mux
-        wire [DATA_WIDTH-1:0]    m_axis_tdata_mux   = int_s_axis_tdata[grant_encoded*DATA_WIDTH +: DATA_WIDTH];
-        wire [KEEP_WIDTH-1:0]    m_axis_tkeep_mux   = int_s_axis_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
-        wire                     m_axis_tvalid_mux  = int_axis_tvalid[grant_encoded*M_COUNT+n] && grant_valid;
+        reg  [DATA_WIDTH-1:0]    m_axis_tdata_mux;
+        reg  [KEEP_WIDTH-1:0]    m_axis_tkeep_mux;
+        reg                      m_axis_tvalid_mux;
         wire                     m_axis_tready_mux;
-        wire                     m_axis_tlast_mux   = int_s_axis_tlast[grant_encoded];
-        wire [M_ID_WIDTH-1:0]    m_axis_tid_mux     = int_s_axis_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH];
-        wire [M_DEST_WIDTH-1:0]  m_axis_tdest_mux   = int_s_axis_tdest[grant_encoded*S_DEST_WIDTH +: S_DEST_WIDTH];
-        wire [USER_WIDTH-1:0]    m_axis_tuser_mux   = int_s_axis_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
+        reg                      m_axis_tlast_mux;
+        reg  [M_ID_WIDTH-1:0]    m_axis_tid_mux;
+        reg  [M_DEST_WIDTH-1:0]  m_axis_tdest_mux;
+        reg  [USER_WIDTH-1:0]    m_axis_tuser_mux;
+
+        always @* begin
+            m_axis_tdata_mux   = int_s_axis_tdata[grant_encoded*DATA_WIDTH +: DATA_WIDTH];
+            m_axis_tkeep_mux   = int_s_axis_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
+            m_axis_tvalid_mux  = int_axis_tvalid[grant_encoded*M_COUNT+n] && grant_valid;
+            m_axis_tlast_mux   = int_s_axis_tlast[grant_encoded];
+            m_axis_tid_mux     = int_s_axis_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH];
+            if (UPDATE_TID && S_COUNT > 1) begin
+                m_axis_tid_mux[M_ID_WIDTH-1:M_ID_WIDTH-CL_S_COUNT] = grant_encoded;
+            end
+            m_axis_tdest_mux   = int_s_axis_tdest[grant_encoded*S_DEST_WIDTH +: S_DEST_WIDTH];
+            m_axis_tuser_mux   = int_s_axis_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
+        end
 
         assign int_axis_tready[n*S_COUNT +: S_COUNT] = (grant_valid && m_axis_tready_mux) << grant_encoded;
 
diff --git a/rtl/axis_switch_wrap.py b/rtl/axis_switch_wrap.py
index 87efd5494ba2b7a48731adbe7a3d316bd13581fc..0af37e02acbfd44393c59679191b1fa78383afd4 100755
--- a/rtl/axis_switch_wrap.py
+++ b/rtl/axis_switch_wrap.py
@@ -107,6 +107,8 @@ module {{name}} #
     // Interface connection control
     parameter M{{'%02d'%p}}_CONNECT = {{m}}'b{% for p in range(m) %}1{% endfor %},
 {%- endfor %}
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // Input interface register type
     // 0 to bypass, 1 for simple buffer, 2 for skid buffer
     parameter S_REG_TYPE = 0,
@@ -175,6 +177,7 @@ axis_switch #(
     .M_BASE({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_BASE){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_TOP({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_TOP){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_CONNECT({ {% for p in range(n-1,-1,-1) %}w_s(M{{'%02d'%p}}_CONNECT){% if not loop.last %}, {% endif %}{% endfor %} }),
+    .UPDATE_TID(UPDATE_TID),
     .S_REG_TYPE(S_REG_TYPE),
     .M_REG_TYPE(M_REG_TYPE),
     .ARB_TYPE_ROUND_ROBIN(ARB_TYPE_ROUND_ROBIN),
diff --git a/tb/axis_arb_mux/Makefile b/tb/axis_arb_mux/Makefile
index 7d91e58974f2ade0cb0003028043466b67e21daa..b4448a5a6f599a3e8e1f18eecef5ed6094ea9540 100644
--- a/tb/axis_arb_mux/Makefile
+++ b/tb/axis_arb_mux/Makefile
@@ -49,6 +49,7 @@ export PARAM_DEST_WIDTH ?= 8
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
 export PARAM_LAST_ENABLE ?= 1
+export PARAM_UPDATE_TID ?= 1
 export PARAM_ARB_TYPE_ROUND_ROBIN ?= 0
 export PARAM_ARB_LSB_HIGH_PRIORITY ?= 1
 
@@ -66,6 +67,7 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).LAST_ENABLE=$(PARAM_LAST_ENABLE)
+	COMPILE_ARGS += -P $(TOPLEVEL).UPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 
@@ -87,6 +89,7 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
 	COMPILE_ARGS += -GLAST_ENABLE=$(PARAM_LAST_ENABLE)
+	COMPILE_ARGS += -GUPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -GARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -GARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 
diff --git a/tb/axis_arb_mux/test_axis_arb_mux.py b/tb/axis_arb_mux/test_axis_arb_mux.py
index d69b0b156094f73d578de87b9f8fcfe48d7cef54..eb9fb334dd071554650d06964ba2855d9f9f2aed 100644
--- a/tb/axis_arb_mux/test_axis_arb_mux.py
+++ b/tb/axis_arb_mux/test_axis_arb_mux.py
@@ -79,7 +79,15 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source[port].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -92,19 +100,21 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
-        test_frame.tid = cur_id
+        test_frame.tid = cur_id | (port << src_shift)
         test_frame.tdest = cur_id
 
         test_frames.append(test_frame)
         await tb.source[port].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for test_frame in test_frames:
         rx_frame = await tb.sink.recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == port
+        assert (rx_frame.tid >> id_width) == port
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -140,7 +150,15 @@ async def run_arb_test(dut):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -153,29 +171,34 @@ async def run_arb_test(dut):
 
     for k in range(5):
         test_frame = AxiStreamFrame(test_data, tx_complete=Event())
-        test_frame.tid = cur_id
+
+        src_ind = 0
 
         if k == 0:
-            test_frame.tdest = 0
+            src_ind = 0
         elif k == 4:
             await test_frames[1].tx_complete.wait()
             for j in range(8):
                 await RisingEdge(dut.clk)
-            test_frame.tdest = 0
+            src_ind = 0
         else:
-            test_frame.tdest = 1
+            src_ind = 1
+
+        test_frame.tid = cur_id | (src_ind << src_shift)
+        test_frame.tdest = 0
 
         test_frames.append(test_frame)
-        await tb.source[test_frame.tdest].send(test_frame)
+        await tb.source[src_ind].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for k in [0, 1, 2, 4, 3]:
         test_frame = test_frames[k]
         rx_frame = await tb.sink.recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -190,7 +213,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -206,13 +237,13 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
             length = random.randint(1, byte_lanes*16)
             test_data = bytearray(itertools.islice(itertools.cycle(range(256)), length))
             test_frame = AxiStreamFrame(test_data)
-            test_frame.tid = p
+            test_frame.tid = cur_id | (p << src_shift)
             test_frame.tdest = cur_id
 
             test_frames[p].append(test_frame)
             await tb.source[p].send(test_frame)
 
-            cur_id = (cur_id + 1) % id_count
+            cur_id = (cur_id + 1) % max_count
 
     while any(test_frames):
         rx_frame = await tb.sink.recv()
@@ -220,14 +251,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
         test_frame = None
 
         for lst in test_frames:
-            if lst and lst[0].tid == rx_frame.tid:
+            if lst and lst[0].tid == (rx_frame.tid & id_mask):
                 test_frame = lst.pop(0)
                 break
 
         assert test_frame is not None
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -323,6 +355,7 @@ def test_axis_arb_mux(request, ports, data_width, round_robin):
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
     parameters['LAST_ENABLE'] = 1
+    parameters['UPDATE_TID'] = 1
     parameters['ARB_TYPE_ROUND_ROBIN'] = round_robin
     parameters['ARB_LSB_HIGH_PRIORITY'] = 1
 
diff --git a/tb/axis_ram_switch/Makefile b/tb/axis_ram_switch/Makefile
index 49866a772443351933630b17d948865b1c02dfbf..6f8ebfd20da917a3346e9a6ad2302d2058f1ac8a 100644
--- a/tb/axis_ram_switch/Makefile
+++ b/tb/axis_ram_switch/Makefile
@@ -60,6 +60,7 @@ export PARAM_USER_BAD_FRAME_VALUE ?= 1
 export PARAM_USER_BAD_FRAME_MASK ?= 1
 export PARAM_DROP_BAD_FRAME ?= 0
 export PARAM_DROP_WHEN_FULL ?= 0
+export PARAM_UPDATE_TID ?= 1
 export PARAM_ARB_TYPE_ROUND_ROBIN ?= 1
 export PARAM_ARB_LSB_HIGH_PRIORITY ?= 1
 export PARAM_RAM_PIPELINE ?= 2
@@ -87,6 +88,7 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_BAD_FRAME_MASK=$(PARAM_USER_BAD_FRAME_MASK)
 	COMPILE_ARGS += -P $(TOPLEVEL).DROP_BAD_FRAME=$(PARAM_DROP_BAD_FRAME)
 	COMPILE_ARGS += -P $(TOPLEVEL).DROP_WHEN_FULL=$(PARAM_DROP_WHEN_FULL)
+	COMPILE_ARGS += -P $(TOPLEVEL).UPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 	COMPILE_ARGS += -P $(TOPLEVEL).RAM_PIPELINE=$(PARAM_RAM_PIPELINE)
@@ -118,6 +120,7 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GUSER_BAD_FRAME_MASK=$(PARAM_USER_BAD_FRAME_MASK)
 	COMPILE_ARGS += -GDROP_BAD_FRAME=$(PARAM_DROP_BAD_FRAME)
 	COMPILE_ARGS += -GDROP_WHEN_FULL=$(PARAM_DROP_WHEN_FULL)
+	COMPILE_ARGS += -GUPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -GARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -GARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 	COMPILE_ARGS += -GRAM_PIPELINE=$(PARAM_RAM_PIPELINE)
diff --git a/tb/axis_ram_switch/test_axis_ram_switch.py b/tb/axis_ram_switch/test_axis_ram_switch.py
index 16380a3913d92c503b5e16e3797cbfc715957bf3..e66c4e99ea39e1b8db85b1f77578256713a69987 100644
--- a/tb/axis_ram_switch/test_axis_ram_switch.py
+++ b/tb/axis_ram_switch/test_axis_ram_switch.py
@@ -81,7 +81,15 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source[s].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -94,19 +102,21 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
-        test_frame.tid = cur_id
+        test_frame.tid = cur_id | (s << src_shift)
         test_frame.tdest = m
 
         test_frames.append(test_frame)
         await tb.source[s].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for test_frame in test_frames:
         rx_frame = await tb.sink[m].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == s
+        assert (rx_frame.tid >> id_width) == s
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -142,7 +152,15 @@ async def run_arb_test(dut):
     tb = TB(dut)
 
     byte_lanes = max(tb.source[0].byte_lanes, tb.sink[0].byte_lanes)
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -155,8 +173,6 @@ async def run_arb_test(dut):
 
     for k in range(5):
         test_frame = AxiStreamFrame(test_data, tx_complete=Event())
-        test_frame.tid = cur_id
-        test_frame.tdest = 0
 
         src_ind = 0
 
@@ -170,17 +186,21 @@ async def run_arb_test(dut):
         else:
             src_ind = 1
 
+        test_frame.tid = cur_id | (src_ind << src_shift)
+        test_frame.tdest = 0
+
         test_frames.append(test_frame)
         await tb.source[src_ind].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for k in [0, 1, 2, 4, 3]:
         test_frame = test_frames[k]
         rx_frame = await tb.sink[0].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -195,7 +215,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
     tb = TB(dut)
 
     byte_lanes = max(tb.source[0].byte_lanes, tb.sink[0].byte_lanes)
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -211,13 +239,13 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
             length = random.randint(1, byte_lanes*16)
             test_data = bytearray(itertools.islice(itertools.cycle(range(256)), length))
             test_frame = AxiStreamFrame(test_data)
-            test_frame.tid = cur_id
+            test_frame.tid = cur_id | (p << src_shift)
             test_frame.tdest = random.randrange(len(tb.sink))
 
             test_frames[p][test_frame.tdest].append(test_frame)
             await tb.source[p].send(test_frame)
 
-            cur_id = (cur_id + 1) % id_count
+            cur_id = (cur_id + 1) % max_count
 
     for lst in test_frames:
         while any(lst):
@@ -227,14 +255,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
 
             for lst_a in test_frames:
                 for lst_b in lst_a:
-                    if lst_b and lst_b[0].tid == rx_frame.tid:
+                    if lst_b and lst_b[0].tid == (rx_frame.tid & id_mask):
                         test_frame = lst_b.pop(0)
                         break
 
             assert test_frame is not None
 
             assert rx_frame.tdata == test_frame.tdata
-            assert rx_frame.tid == test_frame.tid
+            assert (rx_frame.tid & id_mask) == test_frame.tid
+            assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
             assert rx_frame.tdest == test_frame.tdest
             assert not rx_frame.tuser
 
@@ -345,6 +374,7 @@ def test_axis_ram_switch(request, s_count, m_count, s_data_width, m_data_width):
     parameters['USER_BAD_FRAME_MASK'] = 1
     parameters['DROP_BAD_FRAME'] = 0
     parameters['DROP_WHEN_FULL'] = 0
+    parameters['UPDATE_TID'] = 1
     parameters['ARB_TYPE_ROUND_ROBIN'] = 1
     parameters['ARB_LSB_HIGH_PRIORITY'] = 1
     parameters['RAM_PIPELINE'] = 2
diff --git a/tb/axis_switch/Makefile b/tb/axis_switch/Makefile
index 115ea325582939c4aa9bb269aad5d5964905615c..77501e3bb619824e895116ef06c2a969da2d1232 100644
--- a/tb/axis_switch/Makefile
+++ b/tb/axis_switch/Makefile
@@ -50,6 +50,7 @@ export PARAM_M_DEST_WIDTH ?= 8
 export PARAM_S_DEST_WIDTH ?= $(shell python -c "print($(PARAM_M_DEST_WIDTH) + ($(PARAM_M_COUNT)-1).bit_length())")
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
+export PARAM_UPDATE_TID ?= 1
 export PARAM_S_REG_TYPE ?= 0
 export PARAM_M_REG_TYPE ?= 2
 export PARAM_ARB_TYPE_ROUND_ROBIN ?= 1
@@ -68,6 +69,7 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).M_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).UPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -P $(TOPLEVEL).S_REG_TYPE=$(PARAM_S_REG_TYPE)
 	COMPILE_ARGS += -P $(TOPLEVEL).M_REG_TYPE=$(PARAM_M_REG_TYPE)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
@@ -90,6 +92,7 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GM_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -GUPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -GS_REG_TYPE=$(PARAM_S_REG_TYPE)
 	COMPILE_ARGS += -GM_REG_TYPE=$(PARAM_M_REG_TYPE)
 	COMPILE_ARGS += -GARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
diff --git a/tb/axis_switch/test_axis_switch.py b/tb/axis_switch/test_axis_switch.py
index 3097836f365778f7726d7ad19ada518ef35891bf..3cbf32d71b6844f44614691d8ef7625ff086b5f8 100644
--- a/tb/axis_switch/test_axis_switch.py
+++ b/tb/axis_switch/test_axis_switch.py
@@ -81,7 +81,15 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source[s].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -94,19 +102,21 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
-        test_frame.tid = cur_id
+        test_frame.tid = cur_id | (s << src_shift)
         test_frame.tdest = m
 
         test_frames.append(test_frame)
         await tb.source[s].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for test_frame in test_frames:
         rx_frame = await tb.sink[m].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == s
+        assert (rx_frame.tid >> id_width) == s
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -142,7 +152,15 @@ async def run_arb_test(dut):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -155,8 +173,6 @@ async def run_arb_test(dut):
 
     for k in range(5):
         test_frame = AxiStreamFrame(test_data, tx_complete=Event())
-        test_frame.tid = cur_id
-        test_frame.tdest = 0
 
         src_ind = 0
 
@@ -170,17 +186,21 @@ async def run_arb_test(dut):
         else:
             src_ind = 1
 
+        test_frame.tid = cur_id | (src_ind << src_shift)
+        test_frame.tdest = 0
+
         test_frames.append(test_frame)
         await tb.source[src_ind].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for k in [0, 1, 2, 4, 3]:
         test_frame = test_frames[k]
         rx_frame = await tb.sink[0].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -195,7 +215,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -211,13 +239,13 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
             length = random.randint(1, byte_lanes*16)
             test_data = bytearray(itertools.islice(itertools.cycle(range(256)), length))
             test_frame = AxiStreamFrame(test_data)
-            test_frame.tid = cur_id
+            test_frame.tid = cur_id | (p << src_shift)
             test_frame.tdest = random.randrange(len(tb.sink))
 
             test_frames[p][test_frame.tdest].append(test_frame)
             await tb.source[p].send(test_frame)
 
-            cur_id = (cur_id + 1) % id_count
+            cur_id = (cur_id + 1) % max_count
 
     for lst in test_frames:
         while any(lst):
@@ -227,15 +255,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
 
             for lst_a in test_frames:
                 for lst_b in lst_a:
-                    if lst_b and lst_b[0].tid == rx_frame.tid:
+                    if lst_b and lst_b[0].tid == (rx_frame.tid & id_mask):
                         test_frame = lst_b.pop(0)
                         break
 
             assert test_frame is not None
 
             assert rx_frame.tdata == test_frame.tdata
-            assert rx_frame.tid == test_frame.tid
-            assert rx_frame.tdest == test_frame.tdest
+            assert (rx_frame.tid & id_mask) == test_frame.tid
+            assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
             assert not rx_frame.tuser
 
     assert all(s.empty() for s in tb.sink)
@@ -334,6 +362,7 @@ def test_axis_switch(request, s_count, m_count, data_width):
     parameters['S_DEST_WIDTH'] = parameters['M_DEST_WIDTH'] + (m_count-1).bit_length()
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
+    parameters['UPDATE_TID'] = 1
     parameters['S_REG_TYPE'] = 0
     parameters['M_REG_TYPE'] = 2
     parameters['ARB_TYPE_ROUND_ROBIN'] = 1