diff --git a/lib/axis/rtl/axis_arb_mux.v b/lib/axis/rtl/axis_arb_mux.v
index 5f8eb2ddc6d16d8f53aad302167ef3c8b7ffa7b3..83bdd4f6bec5bb77bfccec8637cccddde1bc9816 100644
--- a/lib/axis/rtl/axis_arb_mux.v
+++ b/lib/axis/rtl/axis_arb_mux.v
@@ -43,8 +43,10 @@ module axis_arb_mux #
     parameter KEEP_WIDTH = (DATA_WIDTH/8),
     // Propagate tid signal
     parameter ID_ENABLE = 0,
-    // tid signal width
-    parameter ID_WIDTH = 8,
+    // input tid signal width
+    parameter S_ID_WIDTH = 8,
+    // output tid signal width
+    parameter M_ID_WIDTH = S_ID_WIDTH+$clog2(S_COUNT),
     // Propagate tdest signal
     parameter DEST_ENABLE = 0,
     // tdest signal width
@@ -55,6 +57,8 @@ module axis_arb_mux #
     parameter USER_WIDTH = 1,
     // Propagate tlast signal
     parameter LAST_ENABLE = 1,
+    // Overwrite the top $clog2(S_COUNT) bits of output tid with the granted input index
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 0,
     // LSB priority selection
@@ -67,30 +71,47 @@ module axis_arb_mux #
     /*
      * AXI Stream inputs
      */
-    input  wire [S_COUNT*DATA_WIDTH-1:0] s_axis_tdata,
-    input  wire [S_COUNT*KEEP_WIDTH-1:0] s_axis_tkeep,
-    input  wire [S_COUNT-1:0]            s_axis_tvalid,
-    output wire [S_COUNT-1:0]            s_axis_tready,
-    input  wire [S_COUNT-1:0]            s_axis_tlast,
-    input  wire [S_COUNT*ID_WIDTH-1:0]   s_axis_tid,
-    input  wire [S_COUNT*DEST_WIDTH-1:0] s_axis_tdest,
-    input  wire [S_COUNT*USER_WIDTH-1:0] s_axis_tuser,
+    input  wire [S_COUNT*DATA_WIDTH-1:0]  s_axis_tdata,
+    input  wire [S_COUNT*KEEP_WIDTH-1:0]  s_axis_tkeep,
+    input  wire [S_COUNT-1:0]             s_axis_tvalid,
+    output wire [S_COUNT-1:0]             s_axis_tready,
+    input  wire [S_COUNT-1:0]             s_axis_tlast,
+    input  wire [S_COUNT*S_ID_WIDTH-1:0]  s_axis_tid,
+    input  wire [S_COUNT*DEST_WIDTH-1:0]  s_axis_tdest,
+    input  wire [S_COUNT*USER_WIDTH-1:0]  s_axis_tuser,
 
     /*
      * AXI Stream output
      */
-    output wire [DATA_WIDTH-1:0]         m_axis_tdata,
-    output wire [KEEP_WIDTH-1:0]         m_axis_tkeep,
-    output wire                          m_axis_tvalid,
-    input  wire                          m_axis_tready,
-    output wire                          m_axis_tlast,
-    output wire [ID_WIDTH-1:0]           m_axis_tid,
-    output wire [DEST_WIDTH-1:0]         m_axis_tdest,
-    output wire [USER_WIDTH-1:0]         m_axis_tuser
+    output wire [DATA_WIDTH-1:0]          m_axis_tdata,
+    output wire [KEEP_WIDTH-1:0]          m_axis_tkeep,
+    output wire                           m_axis_tvalid,
+    input  wire                           m_axis_tready,
+    output wire                           m_axis_tlast,
+    output wire [M_ID_WIDTH-1:0]          m_axis_tid,
+    output wire [DEST_WIDTH-1:0]          m_axis_tdest,
+    output wire [USER_WIDTH-1:0]          m_axis_tuser
 );
 
 parameter CL_S_COUNT = $clog2(S_COUNT);
 
+parameter S_ID_WIDTH_INT = S_ID_WIDTH > 0 ? S_ID_WIDTH : 1;
+
+// check configuration
+initial begin
+    if (UPDATE_TID) begin
+        if (!ID_ENABLE) begin
+            $error("Error: UPDATE_TID set requires ID_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (M_ID_WIDTH < CL_S_COUNT) begin
+            $error("Error: M_ID_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+end
+
 wire [S_COUNT-1:0] request;
 wire [S_COUNT-1:0] acknowledge;
 wire [S_COUNT-1:0] grant;
@@ -103,7 +124,7 @@ reg  [KEEP_WIDTH-1:0] m_axis_tkeep_int;
 reg                   m_axis_tvalid_int;
 reg                   m_axis_tready_int_reg = 1'b0;
 reg                   m_axis_tlast_int;
-reg  [ID_WIDTH-1:0]   m_axis_tid_int;
+reg  [M_ID_WIDTH-1:0] m_axis_tid_int;
 reg  [DEST_WIDTH-1:0] m_axis_tdest_int;
 reg  [USER_WIDTH-1:0] m_axis_tuser_int;
 wire                  m_axis_tready_int_early;
@@ -116,7 +137,7 @@ wire [KEEP_WIDTH-1:0] current_s_tkeep  = s_axis_tkeep[grant_encoded*KEEP_WIDTH +
 wire                  current_s_tvalid = s_axis_tvalid[grant_encoded];
 wire                  current_s_tready = s_axis_tready[grant_encoded];
 wire                  current_s_tlast  = s_axis_tlast[grant_encoded];
-wire [ID_WIDTH-1:0]   current_s_tid    = s_axis_tid[grant_encoded*ID_WIDTH +: ID_WIDTH];
+wire [S_ID_WIDTH-1:0] current_s_tid    = s_axis_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH_INT];
 wire [DEST_WIDTH-1:0] current_s_tdest  = s_axis_tdest[grant_encoded*DEST_WIDTH +: DEST_WIDTH];
 wire [USER_WIDTH-1:0] current_s_tuser  = s_axis_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
 
@@ -148,6 +169,9 @@ always @* begin
     m_axis_tvalid_int = current_s_tvalid && m_axis_tready_int_reg && grant_valid;
     m_axis_tlast_int  = current_s_tlast;
     m_axis_tid_int    = current_s_tid;
+    if (UPDATE_TID && S_COUNT > 1) begin
+        m_axis_tid_int[M_ID_WIDTH-1:M_ID_WIDTH-CL_S_COUNT] = grant_encoded;
+    end
     m_axis_tdest_int  = current_s_tdest;
     m_axis_tuser_int  = current_s_tuser;
 end
@@ -157,7 +181,7 @@ reg [DATA_WIDTH-1:0] m_axis_tdata_reg  = {DATA_WIDTH{1'b0}};
 reg [KEEP_WIDTH-1:0] m_axis_tkeep_reg  = {KEEP_WIDTH{1'b0}};
 reg                  m_axis_tvalid_reg = 1'b0, m_axis_tvalid_next;
 reg                  m_axis_tlast_reg  = 1'b0;
-reg [ID_WIDTH-1:0]   m_axis_tid_reg    = {ID_WIDTH{1'b0}};
+reg [M_ID_WIDTH-1:0] m_axis_tid_reg    = {M_ID_WIDTH{1'b0}};
 reg [DEST_WIDTH-1:0] m_axis_tdest_reg  = {DEST_WIDTH{1'b0}};
 reg [USER_WIDTH-1:0] m_axis_tuser_reg  = {USER_WIDTH{1'b0}};
 
@@ -165,7 +189,7 @@ reg [DATA_WIDTH-1:0] temp_m_axis_tdata_reg  = {DATA_WIDTH{1'b0}};
 reg [KEEP_WIDTH-1:0] temp_m_axis_tkeep_reg  = {KEEP_WIDTH{1'b0}};
 reg                  temp_m_axis_tvalid_reg = 1'b0, temp_m_axis_tvalid_next;
 reg                  temp_m_axis_tlast_reg  = 1'b0;
-reg [ID_WIDTH-1:0]   temp_m_axis_tid_reg    = {ID_WIDTH{1'b0}};
+reg [M_ID_WIDTH-1:0] temp_m_axis_tid_reg    = {M_ID_WIDTH{1'b0}};
 reg [DEST_WIDTH-1:0] temp_m_axis_tdest_reg  = {DEST_WIDTH{1'b0}};
 reg [USER_WIDTH-1:0] temp_m_axis_tuser_reg  = {USER_WIDTH{1'b0}};
 
@@ -178,7 +202,7 @@ assign m_axis_tdata  = m_axis_tdata_reg;
 assign m_axis_tkeep  = KEEP_ENABLE ? m_axis_tkeep_reg : {KEEP_WIDTH{1'b1}};
 assign m_axis_tvalid = m_axis_tvalid_reg;
 assign m_axis_tlast  = LAST_ENABLE ? m_axis_tlast_reg : 1'b1;
-assign m_axis_tid    = ID_ENABLE   ? m_axis_tid_reg   : {ID_WIDTH{1'b0}};
+assign m_axis_tid    = ID_ENABLE   ? m_axis_tid_reg   : {M_ID_WIDTH{1'b0}};
 assign m_axis_tdest  = DEST_ENABLE ? m_axis_tdest_reg : {DEST_WIDTH{1'b0}};
 assign m_axis_tuser  = USER_ENABLE ? m_axis_tuser_reg : {USER_WIDTH{1'b0}};
 
diff --git a/lib/axis/rtl/axis_arb_mux_wrap.py b/lib/axis/rtl/axis_arb_mux_wrap.py
index 4cfe2b8fc5aeddd68719b6b6df55fd92811233eb..091d36612846bd21fe8087c8f48b6b6e03a81e8e 100755
--- a/lib/axis/rtl/axis_arb_mux_wrap.py
+++ b/lib/axis/rtl/axis_arb_mux_wrap.py
@@ -78,8 +78,10 @@ module {{name}} #
     parameter KEEP_WIDTH = (DATA_WIDTH/8),
     // Propagate tid signal
     parameter ID_ENABLE = 0,
-    // tid signal width
-    parameter ID_WIDTH = 8,
+    // input tid signal width
+    parameter S_ID_WIDTH = 8,
+    // output tid signal width
+    parameter M_ID_WIDTH = S_ID_WIDTH+{{cn}},
     // Propagate tdest signal
     parameter DEST_ENABLE = 0,
     // tdest signal width
@@ -90,39 +92,41 @@ module {{name}} #
     parameter USER_WIDTH = 1,
     // Propagate tlast signal
     parameter LAST_ENABLE = 1,
+    // Overwrite the MSBs of output tid with the granted input index (forwarded to axis_arb_mux)
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 0,
     // LSB priority selection
     parameter ARB_LSB_HIGH_PRIORITY = 1
 )
 (
-    input  wire                  clk,
-    input  wire                  rst,
+    input  wire                   clk,
+    input  wire                   rst,
 
     /*
      * AXI Stream inputs
      */
 {%- for p in range(n) %}
-    input  wire [DATA_WIDTH-1:0] s{{'%02d'%p}}_axis_tdata,
-    input  wire [KEEP_WIDTH-1:0] s{{'%02d'%p}}_axis_tkeep,
-    input  wire                  s{{'%02d'%p}}_axis_tvalid,
-    output wire                  s{{'%02d'%p}}_axis_tready,
-    input  wire                  s{{'%02d'%p}}_axis_tlast,
-    input  wire [ID_WIDTH-1:0]   s{{'%02d'%p}}_axis_tid,
-    input  wire [DEST_WIDTH-1:0] s{{'%02d'%p}}_axis_tdest,
-    input  wire [USER_WIDTH-1:0] s{{'%02d'%p}}_axis_tuser,
+    input  wire [DATA_WIDTH-1:0]  s{{'%02d'%p}}_axis_tdata,
+    input  wire [KEEP_WIDTH-1:0]  s{{'%02d'%p}}_axis_tkeep,
+    input  wire                   s{{'%02d'%p}}_axis_tvalid,
+    output wire                   s{{'%02d'%p}}_axis_tready,
+    input  wire                   s{{'%02d'%p}}_axis_tlast,
+    input  wire [S_ID_WIDTH-1:0]  s{{'%02d'%p}}_axis_tid,
+    input  wire [DEST_WIDTH-1:0]  s{{'%02d'%p}}_axis_tdest,
+    input  wire [USER_WIDTH-1:0]  s{{'%02d'%p}}_axis_tuser,
 {% endfor %}
     /*
      * AXI Stream output
      */
-    output wire [DATA_WIDTH-1:0] m_axis_tdata,
-    output wire [KEEP_WIDTH-1:0] m_axis_tkeep,
-    output wire                  m_axis_tvalid,
-    input  wire                  m_axis_tready,
-    output wire                  m_axis_tlast,
-    output wire [ID_WIDTH-1:0]   m_axis_tid,
-    output wire [DEST_WIDTH-1:0] m_axis_tdest,
-    output wire [USER_WIDTH-1:0] m_axis_tuser
+    output wire [DATA_WIDTH-1:0]  m_axis_tdata,
+    output wire [KEEP_WIDTH-1:0]  m_axis_tkeep,
+    output wire                   m_axis_tvalid,
+    input  wire                   m_axis_tready,
+    output wire                   m_axis_tlast,
+    output wire [M_ID_WIDTH-1:0]  m_axis_tid,
+    output wire [DEST_WIDTH-1:0]  m_axis_tdest,
+    output wire [USER_WIDTH-1:0]  m_axis_tuser
 );
 
 axis_arb_mux #(
@@ -131,12 +135,14 @@ axis_arb_mux #(
     .KEEP_ENABLE(KEEP_ENABLE),
     .KEEP_WIDTH(KEEP_WIDTH),
     .ID_ENABLE(ID_ENABLE),
-    .ID_WIDTH(ID_WIDTH),
+    .S_ID_WIDTH(S_ID_WIDTH),
+    .M_ID_WIDTH(M_ID_WIDTH),
     .DEST_ENABLE(DEST_ENABLE),
     .DEST_WIDTH(DEST_WIDTH),
     .USER_ENABLE(USER_ENABLE),
     .USER_WIDTH(USER_WIDTH),
     .LAST_ENABLE(LAST_ENABLE),
+    .UPDATE_TID(UPDATE_TID),
     .ARB_TYPE_ROUND_ROBIN(ARB_TYPE_ROUND_ROBIN),
     .ARB_LSB_HIGH_PRIORITY(ARB_LSB_HIGH_PRIORITY)
 )
diff --git a/lib/axis/rtl/axis_demux.v b/lib/axis/rtl/axis_demux.v
index 50b63359fadb58de83fe681b69c5bf39c5732bac..796799e552355cd67f5cad94fd45eaef93fe142f 100644
--- a/lib/axis/rtl/axis_demux.v
+++ b/lib/axis/rtl/axis_demux.v
@@ -47,51 +47,72 @@ module axis_demux #
     parameter ID_WIDTH = 8,
     // Propagate tdest signal
     parameter DEST_ENABLE = 0,
-    // tdest signal width
-    parameter DEST_WIDTH = 8,
+    // output tdest signal width
+    parameter M_DEST_WIDTH = 8,
+    // input tdest signal width
+    parameter S_DEST_WIDTH = M_DEST_WIDTH+$clog2(M_COUNT),
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
-    parameter USER_WIDTH = 1
+    parameter USER_WIDTH = 1,
+    // Select the output port from the top $clog2(M_COUNT) bits of input tdest instead of select/drop
+    parameter TDEST_ROUTE = 0
 )
 (
-    input  wire                          clk,
-    input  wire                          rst,
+    input  wire                             clk,
+    input  wire                             rst,
 
     /*
      * AXI input
      */
-    input  wire [DATA_WIDTH-1:0]         s_axis_tdata,
-    input  wire [KEEP_WIDTH-1:0]         s_axis_tkeep,
-    input  wire                          s_axis_tvalid,
-    output wire                          s_axis_tready,
-    input  wire                          s_axis_tlast,
-    input  wire [ID_WIDTH-1:0]           s_axis_tid,
-    input  wire [DEST_WIDTH-1:0]         s_axis_tdest,
-    input  wire [USER_WIDTH-1:0]         s_axis_tuser,
+    input  wire [DATA_WIDTH-1:0]            s_axis_tdata,
+    input  wire [KEEP_WIDTH-1:0]            s_axis_tkeep,
+    input  wire                             s_axis_tvalid,
+    output wire                             s_axis_tready,
+    input  wire                             s_axis_tlast,
+    input  wire [ID_WIDTH-1:0]              s_axis_tid,
+    input  wire [S_DEST_WIDTH-1:0]          s_axis_tdest,
+    input  wire [USER_WIDTH-1:0]            s_axis_tuser,
 
     /*
      * AXI outputs
      */
-    output wire [M_COUNT*DATA_WIDTH-1:0] m_axis_tdata,
-    output wire [M_COUNT*KEEP_WIDTH-1:0] m_axis_tkeep,
-    output wire [M_COUNT-1:0]            m_axis_tvalid,
-    input  wire [M_COUNT-1:0]            m_axis_tready,
-    output wire [M_COUNT-1:0]            m_axis_tlast,
-    output wire [M_COUNT*ID_WIDTH-1:0]   m_axis_tid,
-    output wire [M_COUNT*DEST_WIDTH-1:0] m_axis_tdest,
-    output wire [M_COUNT*USER_WIDTH-1:0] m_axis_tuser,
+    output wire [M_COUNT*DATA_WIDTH-1:0]    m_axis_tdata,
+    output wire [M_COUNT*KEEP_WIDTH-1:0]    m_axis_tkeep,
+    output wire [M_COUNT-1:0]               m_axis_tvalid,
+    input  wire [M_COUNT-1:0]               m_axis_tready,
+    output wire [M_COUNT-1:0]               m_axis_tlast,
+    output wire [M_COUNT*ID_WIDTH-1:0]      m_axis_tid,
+    output wire [M_COUNT*M_DEST_WIDTH-1:0]  m_axis_tdest,
+    output wire [M_COUNT*USER_WIDTH-1:0]    m_axis_tuser,
 
     /*
      * Control
      */
-    input  wire                          enable,
-    input  wire                          drop,
-    input  wire [$clog2(M_COUNT)-1:0]    select
+    input  wire                             enable,
+    input  wire                             drop,
+    input  wire [$clog2(M_COUNT)-1:0]       select
 );
 
 parameter CL_M_COUNT = $clog2(M_COUNT);
 
+parameter M_DEST_WIDTH_INT = M_DEST_WIDTH > 0 ? M_DEST_WIDTH : 1;
+
+// check configuration
+initial begin
+    if (TDEST_ROUTE) begin
+        if (!DEST_ENABLE) begin
+            $error("Error: TDEST_ROUTE set requires DEST_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (S_DEST_WIDTH < CL_M_COUNT) begin
+            $error("Error: S_DEST_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+end
+
 reg [CL_M_COUNT-1:0] select_reg = {CL_M_COUNT{1'b0}}, select_ctl, select_next;
 reg drop_reg = 1'b0, drop_ctl, drop_next;
 reg frame_reg = 1'b0, frame_ctl, frame_next;
@@ -99,15 +120,15 @@ reg frame_reg = 1'b0, frame_ctl, frame_next;
 reg s_axis_tready_reg = 1'b0, s_axis_tready_next;
 
 // internal datapath
-reg  [DATA_WIDTH-1:0] m_axis_tdata_int;
-reg  [KEEP_WIDTH-1:0] m_axis_tkeep_int;
-reg  [M_COUNT-1:0]    m_axis_tvalid_int;
-reg                   m_axis_tready_int_reg = 1'b0;
-reg                   m_axis_tlast_int;
-reg  [ID_WIDTH-1:0]   m_axis_tid_int;
-reg  [DEST_WIDTH-1:0] m_axis_tdest_int;
-reg  [USER_WIDTH-1:0] m_axis_tuser_int;
-wire                  m_axis_tready_int_early;
+reg  [DATA_WIDTH-1:0]    m_axis_tdata_int;
+reg  [KEEP_WIDTH-1:0]    m_axis_tkeep_int;
+reg  [M_COUNT-1:0]       m_axis_tvalid_int;
+reg                      m_axis_tready_int_reg = 1'b0;
+reg                      m_axis_tlast_int;
+reg  [ID_WIDTH-1:0]      m_axis_tid_int;
+reg  [M_DEST_WIDTH-1:0]  m_axis_tdest_int;
+reg  [USER_WIDTH-1:0]    m_axis_tuser_int;
+wire                     m_axis_tready_int_early;
 
 assign s_axis_tready = s_axis_tready_reg && enable;
 
@@ -131,8 +152,18 @@ always @* begin
 
     if (!frame_reg && s_axis_tvalid && s_axis_tready) begin
         // start of frame, grab select value
-        select_ctl = select;
-        drop_ctl = drop;
+        if (TDEST_ROUTE) begin
+            if (M_COUNT > 1) begin
+                select_ctl = s_axis_tdest[S_DEST_WIDTH-1:S_DEST_WIDTH-CL_M_COUNT];
+                drop_ctl = s_axis_tdest[S_DEST_WIDTH-1:S_DEST_WIDTH-CL_M_COUNT] >= M_COUNT;
+            end else begin
+                select_ctl = 0;
+                drop_ctl = 1'b0;
+            end
+        end else begin
+            select_ctl = select;
+            drop_ctl = drop || select >= M_COUNT;
+        end
         frame_ctl = 1'b1;
         if (!(s_axis_tready && s_axis_tvalid && s_axis_tlast)) begin
             select_next = select_ctl;
@@ -167,21 +198,21 @@ always @(posedge clk) begin
 end
 
 // output datapath logic
-reg [DATA_WIDTH-1:0] m_axis_tdata_reg  = {DATA_WIDTH{1'b0}};
-reg [KEEP_WIDTH-1:0] m_axis_tkeep_reg  = {KEEP_WIDTH{1'b0}};
-reg [M_COUNT-1:0]    m_axis_tvalid_reg = {M_COUNT{1'b0}}, m_axis_tvalid_next;
-reg                  m_axis_tlast_reg  = 1'b0;
-reg [ID_WIDTH-1:0]   m_axis_tid_reg    = {ID_WIDTH{1'b0}};
-reg [DEST_WIDTH-1:0] m_axis_tdest_reg  = {DEST_WIDTH{1'b0}};
-reg [USER_WIDTH-1:0] m_axis_tuser_reg  = {USER_WIDTH{1'b0}};
-
-reg [DATA_WIDTH-1:0] temp_m_axis_tdata_reg  = {DATA_WIDTH{1'b0}};
-reg [KEEP_WIDTH-1:0] temp_m_axis_tkeep_reg  = {KEEP_WIDTH{1'b0}};
-reg [M_COUNT-1:0]    temp_m_axis_tvalid_reg = {M_COUNT{1'b0}}, temp_m_axis_tvalid_next;
-reg                  temp_m_axis_tlast_reg  = 1'b0;
-reg [ID_WIDTH-1:0]   temp_m_axis_tid_reg    = {ID_WIDTH{1'b0}};
-reg [DEST_WIDTH-1:0] temp_m_axis_tdest_reg  = {DEST_WIDTH{1'b0}};
-reg [USER_WIDTH-1:0] temp_m_axis_tuser_reg  = {USER_WIDTH{1'b0}};
+reg [DATA_WIDTH-1:0]    m_axis_tdata_reg  = {DATA_WIDTH{1'b0}};
+reg [KEEP_WIDTH-1:0]    m_axis_tkeep_reg  = {KEEP_WIDTH{1'b0}};
+reg [M_COUNT-1:0]       m_axis_tvalid_reg = {M_COUNT{1'b0}}, m_axis_tvalid_next;
+reg                     m_axis_tlast_reg  = 1'b0;
+reg [ID_WIDTH-1:0]      m_axis_tid_reg    = {ID_WIDTH{1'b0}};
+reg [M_DEST_WIDTH-1:0]  m_axis_tdest_reg  = {M_DEST_WIDTH_INT{1'b0}};
+reg [USER_WIDTH-1:0]    m_axis_tuser_reg  = {USER_WIDTH{1'b0}};
+
+reg [DATA_WIDTH-1:0]    temp_m_axis_tdata_reg  = {DATA_WIDTH{1'b0}};
+reg [KEEP_WIDTH-1:0]    temp_m_axis_tkeep_reg  = {KEEP_WIDTH{1'b0}};
+reg [M_COUNT-1:0]       temp_m_axis_tvalid_reg = {M_COUNT{1'b0}}, temp_m_axis_tvalid_next;
+reg                     temp_m_axis_tlast_reg  = 1'b0;
+reg [ID_WIDTH-1:0]      temp_m_axis_tid_reg    = {ID_WIDTH{1'b0}};
+reg [M_DEST_WIDTH-1:0]  temp_m_axis_tdest_reg  = {M_DEST_WIDTH_INT{1'b0}};
+reg [USER_WIDTH-1:0]    temp_m_axis_tuser_reg  = {USER_WIDTH{1'b0}};
 
 // datapath control
 reg store_axis_int_to_output;
@@ -193,7 +224,7 @@ assign m_axis_tkeep  = KEEP_ENABLE ? {M_COUNT{m_axis_tkeep_reg}} : {M_COUNT*KEEP
 assign m_axis_tvalid = m_axis_tvalid_reg;
 assign m_axis_tlast  = {M_COUNT{m_axis_tlast_reg}};
 assign m_axis_tid    = ID_ENABLE   ? {M_COUNT{m_axis_tid_reg}}   : {M_COUNT*ID_WIDTH{1'b0}};
-assign m_axis_tdest  = DEST_ENABLE ? {M_COUNT{m_axis_tdest_reg}} : {M_COUNT*DEST_WIDTH{1'b0}};
+assign m_axis_tdest  = DEST_ENABLE ? {M_COUNT{m_axis_tdest_reg}} : {M_COUNT*M_DEST_WIDTH_INT{1'b0}};
 assign m_axis_tuser  = USER_ENABLE ? {M_COUNT{m_axis_tuser_reg}} : {M_COUNT*USER_WIDTH{1'b0}};
 
 // enable ready input next cycle if output is ready or the temp reg will not be filled on the next cycle (output reg empty or no input)
diff --git a/lib/axis/rtl/axis_demux_wrap.py b/lib/axis/rtl/axis_demux_wrap.py
index 8f69561db4b91825ab739416399ddbfe37a9691c..344c8bcb053654b1251c7559df08a6c4aea3f702 100755
--- a/lib/axis/rtl/axis_demux_wrap.py
+++ b/lib/axis/rtl/axis_demux_wrap.py
@@ -82,49 +82,52 @@ module {{name}} #
     parameter ID_WIDTH = 8,
     // Propagate tdest signal
     parameter DEST_ENABLE = 0,
-    // tdest signal width
-    parameter DEST_WIDTH = 8,
+    // output tdest signal width (default matches axis_demux and the former DEST_WIDTH)
+    parameter M_DEST_WIDTH = 8,
+    // input tdest signal width
+    parameter S_DEST_WIDTH = M_DEST_WIDTH+{{cn}},
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
-    parameter USER_WIDTH = 1
+    parameter USER_WIDTH = 1,
+    // Select the output port from the MSBs of input tdest instead of select/drop (forwarded to axis_demux)
+    parameter TDEST_ROUTE = 0
 )
 (
-    input  wire                  clk,
-    input  wire                  rst,
+    input  wire                     clk,
+    input  wire                     rst,
 
     /*
      * AXI Stream input
      */
-    input  wire [DATA_WIDTH-1:0] s_axis_tdata,
-    input  wire [KEEP_WIDTH-1:0] s_axis_tkeep,
-    input  wire                  s_axis_tvalid,
-    output wire                  s_axis_tready,
-    input  wire                  s_axis_tlast,
-    input  wire [ID_WIDTH-1:0]   s_axis_tid,
-    input  wire [DEST_WIDTH-1:0] s_axis_tdest,
-    input  wire [USER_WIDTH-1:0] s_axis_tuser,
+    input  wire [DATA_WIDTH-1:0]    s_axis_tdata,
+    input  wire [KEEP_WIDTH-1:0]    s_axis_tkeep,
+    input  wire                     s_axis_tvalid,
+    output wire                     s_axis_tready,
+    input  wire                     s_axis_tlast,
+    input  wire [ID_WIDTH-1:0]      s_axis_tid,
+    input  wire [S_DEST_WIDTH-1:0]  s_axis_tdest,
+    input  wire [USER_WIDTH-1:0]    s_axis_tuser,
 
     /*
      * AXI Stream outputs
      */
 {%- for p in range(n) %}
-    output wire [DATA_WIDTH-1:0] m{{'%02d'%p}}_axis_tdata,
-    output wire [KEEP_WIDTH-1:0] m{{'%02d'%p}}_axis_tkeep,
-    output wire                  m{{'%02d'%p}}_axis_tvalid,
-    input  wire                  m{{'%02d'%p}}_axis_tready,
-    output wire                  m{{'%02d'%p}}_axis_tlast,
-    output wire [ID_WIDTH-1:0]   m{{'%02d'%p}}_axis_tid,
-    output wire [DEST_WIDTH-1:0] m{{'%02d'%p}}_axis_tdest,
-    output wire [USER_WIDTH-1:0] m{{'%02d'%p}}_axis_tuser,
-{% endfor -%}
-
+    output wire [DATA_WIDTH-1:0]    m{{'%02d'%p}}_axis_tdata,
+    output wire [KEEP_WIDTH-1:0]    m{{'%02d'%p}}_axis_tkeep,
+    output wire                     m{{'%02d'%p}}_axis_tvalid,
+    input  wire                     m{{'%02d'%p}}_axis_tready,
+    output wire                     m{{'%02d'%p}}_axis_tlast,
+    output wire [ID_WIDTH-1:0]      m{{'%02d'%p}}_axis_tid,
+    output wire [M_DEST_WIDTH-1:0]  m{{'%02d'%p}}_axis_tdest,
+    output wire [USER_WIDTH-1:0]    m{{'%02d'%p}}_axis_tuser,
+{% endfor %}
     /*
      * Control
      */
-    input  wire                  enable,
-    input  wire                  drop,
-    input  wire [{{cn-1}}:0]            select
+    input  wire                     enable,
+    input  wire                     drop,
+    input  wire [{{cn-1}}:0]               select
 );
 
 axis_demux #(
@@ -135,9 +138,11 @@ axis_demux #(
     .ID_ENABLE(ID_ENABLE),
     .ID_WIDTH(ID_WIDTH),
     .DEST_ENABLE(DEST_ENABLE),
-    .DEST_WIDTH(DEST_WIDTH),
+    .S_DEST_WIDTH(S_DEST_WIDTH),
+    .M_DEST_WIDTH(M_DEST_WIDTH),
     .USER_ENABLE(USER_ENABLE),
-    .USER_WIDTH(USER_WIDTH)
+    .USER_WIDTH(USER_WIDTH),
+    .TDEST_ROUTE(TDEST_ROUTE)
 )
 axis_demux_inst (
     .clk(clk),
diff --git a/lib/axis/rtl/axis_ram_switch.v b/lib/axis/rtl/axis_ram_switch.v
index c936232c55fb13d11be96b88852628aa957e4953..ed831fe28435359750f45f5df7f714de42d7ae59 100644
--- a/lib/axis/rtl/axis_ram_switch.v
+++ b/lib/axis/rtl/axis_ram_switch.v
@@ -61,11 +61,15 @@ module axis_ram_switch #
     parameter M_KEEP_WIDTH = (M_DATA_WIDTH/8),
     // Propagate tid signal
     parameter ID_ENABLE = 0,
-    // tid signal width
-    parameter ID_WIDTH = 8,
-    // tdest signal width
+    // input tid signal width
+    parameter S_ID_WIDTH = 8,
+    // output tid signal width
+    parameter M_ID_WIDTH = S_ID_WIDTH+$clog2(S_COUNT),
+    // output tdest signal width
+    parameter M_DEST_WIDTH = 1,
+    // input tdest signal width
     // must be wide enough to uniquely address outputs
-    parameter DEST_WIDTH = $clog2(M_COUNT),
+    parameter S_DEST_WIDTH = M_DEST_WIDTH+$clog2(M_COUNT),
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
@@ -80,18 +84,20 @@ module axis_ram_switch #
     // When set, s_axis_tready is always asserted
     parameter DROP_WHEN_FULL = 0,
     // Output interface routing base tdest selection
-    // Concatenate M_COUNT DEST_WIDTH sized constants
+    // Concatenate M_COUNT S_DEST_WIDTH sized constants
     // Port selected if M_BASE <= tdest <= M_TOP
     // set to zero for default routing with tdest MSBs as port index
     parameter M_BASE = 0,
     // Output interface routing top tdest selection
-    // Concatenate M_COUNT DEST_WIDTH sized constants
+    // Concatenate M_COUNT S_DEST_WIDTH sized constants
     // Port selected if M_BASE <= tdest <= M_TOP
     // set to zero to inherit from M_BASE
     parameter M_TOP = 0,
     // Interface connection control
     // M_COUNT concatenated fields of S_COUNT bits
     parameter M_CONNECT = {M_COUNT{{S_COUNT{1'b1}}}},
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 1,
     // LSB priority selection
@@ -100,44 +106,47 @@ module axis_ram_switch #
     parameter RAM_PIPELINE = 2
 )
 (
-    input  wire                            clk,
-    input  wire                            rst,
+    input  wire                             clk,
+    input  wire                             rst,
 
     /*
      * AXI Stream inputs
      */
-    input  wire [S_COUNT*S_DATA_WIDTH-1:0] s_axis_tdata,
-    input  wire [S_COUNT*S_KEEP_WIDTH-1:0] s_axis_tkeep,
-    input  wire [S_COUNT-1:0]              s_axis_tvalid,
-    output wire [S_COUNT-1:0]              s_axis_tready,
-    input  wire [S_COUNT-1:0]              s_axis_tlast,
-    input  wire [S_COUNT*ID_WIDTH-1:0]     s_axis_tid,
-    input  wire [S_COUNT*DEST_WIDTH-1:0]   s_axis_tdest,
-    input  wire [S_COUNT*USER_WIDTH-1:0]   s_axis_tuser,
+    input  wire [S_COUNT*S_DATA_WIDTH-1:0]  s_axis_tdata,
+    input  wire [S_COUNT*S_KEEP_WIDTH-1:0]  s_axis_tkeep,
+    input  wire [S_COUNT-1:0]               s_axis_tvalid,
+    output wire [S_COUNT-1:0]               s_axis_tready,
+    input  wire [S_COUNT-1:0]               s_axis_tlast,
+    input  wire [S_COUNT*S_ID_WIDTH-1:0]    s_axis_tid,
+    input  wire [S_COUNT*S_DEST_WIDTH-1:0]  s_axis_tdest,
+    input  wire [S_COUNT*USER_WIDTH-1:0]    s_axis_tuser,
 
     /*
      * AXI Stream outputs
      */
-    output wire [M_COUNT*M_DATA_WIDTH-1:0] m_axis_tdata,
-    output wire [M_COUNT*M_KEEP_WIDTH-1:0] m_axis_tkeep,
-    output wire [M_COUNT-1:0]              m_axis_tvalid,
-    input  wire [M_COUNT-1:0]              m_axis_tready,
-    output wire [M_COUNT-1:0]              m_axis_tlast,
-    output wire [M_COUNT*ID_WIDTH-1:0]     m_axis_tid,
-    output wire [M_COUNT*DEST_WIDTH-1:0]   m_axis_tdest,
-    output wire [M_COUNT*USER_WIDTH-1:0]   m_axis_tuser,
+    output wire [M_COUNT*M_DATA_WIDTH-1:0]  m_axis_tdata,
+    output wire [M_COUNT*M_KEEP_WIDTH-1:0]  m_axis_tkeep,
+    output wire [M_COUNT-1:0]               m_axis_tvalid,
+    input  wire [M_COUNT-1:0]               m_axis_tready,
+    output wire [M_COUNT-1:0]               m_axis_tlast,
+    output wire [M_COUNT*M_ID_WIDTH-1:0]    m_axis_tid,
+    output wire [M_COUNT*M_DEST_WIDTH-1:0]  m_axis_tdest,
+    output wire [M_COUNT*USER_WIDTH-1:0]    m_axis_tuser,
 
     /*
      * Status
      */
-    output wire [S_COUNT-1:0]              status_overflow,
-    output wire [S_COUNT-1:0]              status_bad_frame,
-    output wire [S_COUNT-1:0]              status_good_frame
+    output wire [S_COUNT-1:0]               status_overflow,
+    output wire [S_COUNT-1:0]               status_bad_frame,
+    output wire [S_COUNT-1:0]               status_good_frame
 );
 
 parameter CL_S_COUNT = $clog2(S_COUNT);
 parameter CL_M_COUNT = $clog2(M_COUNT);
 
+parameter S_ID_WIDTH_INT = S_ID_WIDTH > 0 ? S_ID_WIDTH : 1;
+parameter M_DEST_WIDTH_INT = M_DEST_WIDTH > 0 ? M_DEST_WIDTH : 1;
+
 // force keep width to 1 when disabled
 parameter S_KEEP_WIDTH_INT = S_KEEP_ENABLE ? S_KEEP_WIDTH : 1;
 parameter M_KEEP_WIDTH_INT = M_KEEP_ENABLE ? M_KEEP_WIDTH : 1;
@@ -179,30 +188,42 @@ initial begin
         $finish;
     end
 
-    if (DEST_WIDTH < CL_M_COUNT) begin
-        $error("Error: DEST_WIDTH too small for port count (instance %m)");
+    if (S_DEST_WIDTH < CL_M_COUNT) begin
+        $error("Error: S_DEST_WIDTH too small for port count (instance %m)");
         $finish;
     end
 
+    if (UPDATE_TID) begin
+        if (!ID_ENABLE) begin
+            $error("Error: UPDATE_TID set requires ID_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (M_ID_WIDTH < CL_S_COUNT) begin
+            $error("Error: M_ID_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+
     if (M_BASE == 0) begin
         // M_BASE is zero, route with tdest as port index
         $display("Addressing configuration for axis_switch instance %m");
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            $display("%d: %08x-%08x (connect mask %b)", i, i << (DEST_WIDTH-CL_M_COUNT), ((i+1) << (DEST_WIDTH-CL_M_COUNT))-1, M_CONNECT[i*S_COUNT +: S_COUNT]);
+            $display("%d: %08x-%08x (connect mask %b)", i, i << (S_DEST_WIDTH-CL_M_COUNT), ((i+1) << (S_DEST_WIDTH-CL_M_COUNT))-1, M_CONNECT[i*S_COUNT +: S_COUNT]);
         end
 
     end else if (M_TOP == 0) begin
         // M_TOP is zero, assume equal to M_BASE
         $display("Addressing configuration for axis_switch instance %m");
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            $display("%d: %08x (connect mask %b)", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
+            $display("%d: %08x (connect mask %b)", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
         end
 
         for (i = 0; i < M_COUNT; i = i + 1) begin
             for (j = i+1; j < M_COUNT; j = j + 1) begin
-                if (M_BASE[i*DEST_WIDTH +: DEST_WIDTH] == M_BASE[j*DEST_WIDTH +: DEST_WIDTH]) begin
-                    $display("%d: %08x", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH]);
-                    $display("%d: %08x", j, M_BASE[j*DEST_WIDTH +: DEST_WIDTH]);
+                if (M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH] == M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH]) begin
+                    $display("%d: %08x", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH]);
+                    $display("%d: %08x", j, M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH]);
                     $error("Error: ranges overlap (instance %m)");
                     $finish;
                 end
@@ -211,11 +232,11 @@ initial begin
     end else begin
         $display("Addressing configuration for axis_switch instance %m");
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            $display("%d: %08x-%08x (connect mask %b)", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH], M_TOP[i*DEST_WIDTH +: DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
+            $display("%d: %08x-%08x (connect mask %b)", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
         end
 
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            if (M_BASE[i*DEST_WIDTH +: DEST_WIDTH] > M_TOP[i*DEST_WIDTH +: DEST_WIDTH]) begin
+            if (M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH] > M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH]) begin
                 $error("Error: invalid range (instance %m)");
                 $finish;
             end
@@ -223,9 +244,9 @@ initial begin
 
         for (i = 0; i < M_COUNT; i = i + 1) begin
             for (j = i+1; j < M_COUNT; j = j + 1) begin
-                if (M_BASE[i*DEST_WIDTH +: DEST_WIDTH] <= M_TOP[j*DEST_WIDTH +: DEST_WIDTH] && M_BASE[j*DEST_WIDTH +: DEST_WIDTH] <= M_TOP[i*DEST_WIDTH +: DEST_WIDTH]) begin
-                    $display("%d: %08x-%08x", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH], M_TOP[i*DEST_WIDTH +: DEST_WIDTH]);
-                    $display("%d: %08x-%08x", j, M_BASE[j*DEST_WIDTH +: DEST_WIDTH], M_TOP[j*DEST_WIDTH +: DEST_WIDTH]);
+                if (M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH] <= M_TOP[j*S_DEST_WIDTH +: S_DEST_WIDTH] && M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH] <= M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH]) begin
+                    $display("%d: %08x-%08x", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH]);
+                    $display("%d: %08x-%08x", j, M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH], M_TOP[j*S_DEST_WIDTH +: S_DEST_WIDTH]);
                     $error("Error: ranges overlap (instance %m)");
                     $finish;
                 end
@@ -349,21 +370,21 @@ always @(posedge clk) begin
 end
 
 // Interconnect
-wire [S_COUNT*RAM_ADDR_WIDTH-1:0] int_cmd_addr;
-wire [S_COUNT*ADDR_WIDTH-1:0]     int_cmd_len;
-wire [S_COUNT*CMD_ADDR_WIDTH-1:0] int_cmd_id;
-wire [S_COUNT*KEEP_WIDTH-1:0]     int_cmd_tkeep;
-wire [S_COUNT*ID_WIDTH-1:0]       int_cmd_tid;
-wire [S_COUNT*DEST_WIDTH-1:0]     int_cmd_tdest;
-wire [S_COUNT*USER_WIDTH-1:0]     int_cmd_tuser;
+wire [S_COUNT*RAM_ADDR_WIDTH-1:0]  int_cmd_addr;
+wire [S_COUNT*ADDR_WIDTH-1:0]      int_cmd_len;
+wire [S_COUNT*CMD_ADDR_WIDTH-1:0]  int_cmd_id;
+wire [S_COUNT*KEEP_WIDTH-1:0]      int_cmd_tkeep;
+wire [S_COUNT*S_ID_WIDTH-1:0]      int_cmd_tid;
+wire [S_COUNT*S_DEST_WIDTH-1:0]    int_cmd_tdest;
+wire [S_COUNT*USER_WIDTH-1:0]      int_cmd_tuser;
 
-wire [S_COUNT*M_COUNT-1:0]        int_cmd_valid;
-wire [M_COUNT*S_COUNT-1:0]        int_cmd_ready;
+wire [S_COUNT*M_COUNT-1:0]         int_cmd_valid;
+wire [M_COUNT*S_COUNT-1:0]         int_cmd_ready;
 
-wire [M_COUNT*CMD_ADDR_WIDTH-1:0] int_cmd_status_id;
+wire [M_COUNT*CMD_ADDR_WIDTH-1:0]  int_cmd_status_id;
 
-wire [M_COUNT*S_COUNT-1:0]        int_cmd_status_valid;
-wire [S_COUNT*M_COUNT-1:0]        int_cmd_status_ready;
+wire [M_COUNT*S_COUNT-1:0]         int_cmd_status_valid;
+wire [S_COUNT*M_COUNT-1:0]         int_cmd_status_ready;
 
 generate
 
@@ -371,14 +392,14 @@ generate
 
     for (m = 0; m < S_COUNT; m = m + 1) begin : s_ifaces
 
-        wire [DATA_WIDTH-1:0] port_axis_tdata;
-        wire [KEEP_WIDTH-1:0] port_axis_tkeep;
-        wire                  port_axis_tvalid;
-        wire                  port_axis_tready;
-        wire                  port_axis_tlast;
-        wire [ID_WIDTH-1:0]   port_axis_tid;
-        wire [DEST_WIDTH-1:0] port_axis_tdest;
-        wire [USER_WIDTH-1:0] port_axis_tuser;
+        wire [DATA_WIDTH-1:0]    port_axis_tdata;
+        wire [KEEP_WIDTH-1:0]    port_axis_tkeep;
+        wire                     port_axis_tvalid;
+        wire                     port_axis_tready;
+        wire                     port_axis_tlast;
+        wire [S_ID_WIDTH-1:0]    port_axis_tid;
+        wire [S_DEST_WIDTH-1:0]  port_axis_tdest;
+        wire [USER_WIDTH-1:0]    port_axis_tuser;
 
         axis_adapter #(
             .S_DATA_WIDTH(S_DATA_WIDTH),
@@ -387,10 +408,10 @@ generate
             .M_DATA_WIDTH(DATA_WIDTH),
             .M_KEEP_ENABLE(1),
             .M_KEEP_WIDTH(KEEP_WIDTH),
-            .ID_ENABLE(ID_ENABLE),
-            .ID_WIDTH(ID_WIDTH),
+            .ID_ENABLE(ID_ENABLE && S_ID_WIDTH > 0),
+            .ID_WIDTH(S_ID_WIDTH_INT),
             .DEST_ENABLE(1),
-            .DEST_WIDTH(DEST_WIDTH),
+            .DEST_WIDTH(S_DEST_WIDTH),
             .USER_ENABLE(USER_ENABLE),
             .USER_WIDTH(USER_WIDTH)
         )
@@ -403,8 +424,8 @@ generate
             .s_axis_tvalid(s_axis_tvalid[m]),
             .s_axis_tready(s_axis_tready[m]),
             .s_axis_tlast(s_axis_tlast[m]),
-            .s_axis_tid(s_axis_tid[ID_WIDTH*m +: ID_WIDTH]),
-            .s_axis_tdest(s_axis_tdest[DEST_WIDTH*m +: DEST_WIDTH]),
+            .s_axis_tid(s_axis_tid[S_ID_WIDTH*m +: S_ID_WIDTH_INT]),
+            .s_axis_tdest(s_axis_tdest[S_DEST_WIDTH*m +: S_DEST_WIDTH]),
             .s_axis_tuser(s_axis_tuser[USER_WIDTH*m +: USER_WIDTH]),
             // AXI output
             .m_axis_tdata(port_axis_tdata),
@@ -442,7 +463,7 @@ generate
                             drop_next = 1'b0;
                         end else begin
                             // M_BASE is zero, route with $clog2(M_COUNT) MSBs of tdest as port index
-                            if (port_axis_tdest[DEST_WIDTH-CL_M_COUNT +: CL_M_COUNT] == k && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
+                            if (port_axis_tdest[S_DEST_WIDTH-CL_M_COUNT +: CL_M_COUNT] == k && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
                                 select_next = k;
                                 select_valid_next = 1'b1;
                                 drop_next = 1'b0;
@@ -450,13 +471,13 @@ generate
                         end
                     end else if (M_TOP == 0) begin
                         // M_TOP is zero, assume equal to M_BASE
-                        if (port_axis_tdest == M_BASE[k*DEST_WIDTH +: DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
+                        if (port_axis_tdest == M_BASE[k*S_DEST_WIDTH +: S_DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
                             select_next = k;
                             select_valid_next = 1'b1;
                             drop_next = 1'b0;
                         end
                     end else begin
-                        if (port_axis_tdest >= M_BASE[k*DEST_WIDTH +: DEST_WIDTH] && port_axis_tdest <= M_TOP[k*DEST_WIDTH +: DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
+                        if (port_axis_tdest >= M_BASE[k*S_DEST_WIDTH +: S_DEST_WIDTH] && port_axis_tdest <= M_TOP[k*S_DEST_WIDTH +: S_DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
                             select_next = k;
                             select_valid_next = 1'b1;
                             drop_next = 1'b0;
@@ -542,8 +563,8 @@ generate
         reg [ADDR_WIDTH-1:0] cmd_table_len[2**CMD_ADDR_WIDTH-1:0];
         reg [CL_M_COUNT-1:0] cmd_table_select[2**CMD_ADDR_WIDTH-1:0];
         reg [KEEP_WIDTH-1:0] cmd_table_tkeep[2**CMD_ADDR_WIDTH-1:0];
-        reg [ID_WIDTH-1:0] cmd_table_tid[2**CMD_ADDR_WIDTH-1:0];
-        reg [DEST_WIDTH-1:0] cmd_table_tdest[2**CMD_ADDR_WIDTH-1:0];
+        reg [S_ID_WIDTH-1:0] cmd_table_tid[2**CMD_ADDR_WIDTH-1:0];
+        reg [S_DEST_WIDTH-1:0] cmd_table_tdest[2**CMD_ADDR_WIDTH-1:0];
         reg [USER_WIDTH-1:0] cmd_table_tuser[2**CMD_ADDR_WIDTH-1:0];
 
         reg [CMD_ADDR_WIDTH+1-1:0] cmd_table_start_ptr_reg = 0;
@@ -552,8 +573,8 @@ generate
         reg [ADDR_WIDTH-1:0] cmd_table_start_len;
         reg [CL_M_COUNT-1:0] cmd_table_start_select;
         reg [KEEP_WIDTH-1:0] cmd_table_start_tkeep;
-        reg [ID_WIDTH-1:0] cmd_table_start_tid;
-        reg [DEST_WIDTH-1:0] cmd_table_start_tdest;
+        reg [S_ID_WIDTH-1:0] cmd_table_start_tid;
+        reg [S_DEST_WIDTH-1:0] cmd_table_start_tdest;
         reg [USER_WIDTH-1:0] cmd_table_start_tuser;
         reg cmd_table_start_en;
         reg [CMD_ADDR_WIDTH+1-1:0] cmd_table_read_ptr_reg = 0;
@@ -563,14 +584,14 @@ generate
         reg [CMD_ADDR_WIDTH+1-1:0] cmd_table_finish_ptr_reg = 0;
         reg cmd_table_finish_en;
 
-        reg [RAM_ADDR_WIDTH-1:0] cmd_addr_reg = {RAM_ADDR_WIDTH{1'b0}}, cmd_addr_next;
-        reg [ADDR_WIDTH-1:0]     cmd_len_reg = {ADDR_WIDTH{1'b0}}, cmd_len_next;
-        reg [CMD_ADDR_WIDTH-1:0] cmd_id_reg = {CMD_ADDR_WIDTH{1'b0}}, cmd_id_next;
-        reg [KEEP_WIDTH-1:0]     cmd_tkeep_reg = {KEEP_WIDTH{1'b0}}, cmd_tkeep_next;
-        reg [ID_WIDTH-1:0]       cmd_tid_reg = {ID_WIDTH{1'b0}}, cmd_tid_next;
-        reg [DEST_WIDTH-1:0]     cmd_tdest_reg = {DEST_WIDTH{1'b0}}, cmd_tdest_next;
-        reg [USER_WIDTH-1:0]     cmd_tuser_reg = {USER_WIDTH{1'b0}}, cmd_tuser_next;
-        reg [M_COUNT-1:0]        cmd_valid_reg = 0, cmd_valid_next;
+        reg [RAM_ADDR_WIDTH-1:0]  cmd_addr_reg = {RAM_ADDR_WIDTH{1'b0}}, cmd_addr_next;
+        reg [ADDR_WIDTH-1:0]      cmd_len_reg = {ADDR_WIDTH{1'b0}}, cmd_len_next;
+        reg [CMD_ADDR_WIDTH-1:0]  cmd_id_reg = {CMD_ADDR_WIDTH{1'b0}}, cmd_id_next;
+        reg [KEEP_WIDTH-1:0]      cmd_tkeep_reg = {KEEP_WIDTH{1'b0}}, cmd_tkeep_next;
+        reg [S_ID_WIDTH-1:0]      cmd_tid_reg = {S_ID_WIDTH_INT{1'b0}}, cmd_tid_next;
+        reg [S_DEST_WIDTH-1:0]    cmd_tdest_reg = {S_DEST_WIDTH{1'b0}}, cmd_tdest_next;
+        reg [USER_WIDTH-1:0]      cmd_tuser_reg = {USER_WIDTH{1'b0}}, cmd_tuser_next;
+        reg [M_COUNT-1:0]         cmd_valid_reg = 0, cmd_valid_next;
 
         reg cmd_status_ready_reg = 1'b0, cmd_status_ready_next;
 
@@ -590,8 +611,8 @@ generate
         assign int_cmd_len[m*ADDR_WIDTH +: ADDR_WIDTH] = cmd_len_reg;
         assign int_cmd_id[m*CMD_ADDR_WIDTH +: CMD_ADDR_WIDTH] = cmd_id_reg;
         assign int_cmd_tkeep[m*KEEP_WIDTH +: KEEP_WIDTH] = cmd_tkeep_reg;
-        assign int_cmd_tid[m*ID_WIDTH +: ID_WIDTH] = cmd_tid_reg;
-        assign int_cmd_tdest[m*DEST_WIDTH +: DEST_WIDTH] = cmd_tdest_reg;
+        assign int_cmd_tid[m*S_ID_WIDTH +: S_ID_WIDTH_INT] = cmd_tid_reg;
+        assign int_cmd_tdest[m*S_DEST_WIDTH +: S_DEST_WIDTH] = cmd_tdest_reg;
         assign int_cmd_tuser[m*USER_WIDTH +: USER_WIDTH] = cmd_tuser_reg;
         assign int_cmd_valid[m*M_COUNT +: M_COUNT] = cmd_valid_reg;
 
@@ -821,16 +842,30 @@ generate
         );
 
         // mux
-        wire [RAM_ADDR_WIDTH-1:0] cmd_addr_mux  = int_cmd_addr[grant_encoded*RAM_ADDR_WIDTH +: RAM_ADDR_WIDTH];
-        wire [ADDR_WIDTH-1:0]     cmd_len_mux   = int_cmd_len[grant_encoded*ADDR_WIDTH +: ADDR_WIDTH];
-        wire [CMD_ADDR_WIDTH-1:0] cmd_id_mux    = int_cmd_id[grant_encoded*CMD_ADDR_WIDTH +: CMD_ADDR_WIDTH];
-        wire [KEEP_WIDTH-1:0]     cmd_tkeep_mux = int_cmd_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
-        wire [ID_WIDTH-1:0]       cmd_tid_mux   = int_cmd_tid[grant_encoded*ID_WIDTH +: ID_WIDTH];
-        wire [DEST_WIDTH-1:0]     cmd_tdest_mux = int_cmd_tdest[grant_encoded*DEST_WIDTH +: DEST_WIDTH];
-        wire [USER_WIDTH-1:0]     cmd_tuser_mux = int_cmd_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
-        wire                      cmd_valid_mux = int_cmd_valid[grant_encoded*M_COUNT+n] && grant_valid;
+        reg  [RAM_ADDR_WIDTH-1:0] cmd_addr_mux;
+        reg  [ADDR_WIDTH-1:0]     cmd_len_mux;
+        reg  [CMD_ADDR_WIDTH-1:0] cmd_id_mux;
+        reg  [KEEP_WIDTH-1:0]     cmd_tkeep_mux;
+        reg  [M_ID_WIDTH-1:0]     cmd_tid_mux;
+        reg  [M_DEST_WIDTH-1:0]   cmd_tdest_mux;
+        reg  [USER_WIDTH-1:0]     cmd_tuser_mux;
+        reg                       cmd_valid_mux;
         wire                      cmd_ready_mux;
 
+        always @* begin
+            cmd_addr_mux  = int_cmd_addr[grant_encoded*RAM_ADDR_WIDTH +: RAM_ADDR_WIDTH];
+            cmd_len_mux   = int_cmd_len[grant_encoded*ADDR_WIDTH +: ADDR_WIDTH];
+            cmd_id_mux    = int_cmd_id[grant_encoded*CMD_ADDR_WIDTH +: CMD_ADDR_WIDTH];
+            cmd_tkeep_mux = int_cmd_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
+            cmd_tid_mux   = int_cmd_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH_INT];
+            if (UPDATE_TID && S_COUNT > 1) begin
+                cmd_tid_mux[M_ID_WIDTH-1:M_ID_WIDTH-CL_S_COUNT] = grant_encoded;
+            end
+            cmd_tdest_mux = int_cmd_tdest[grant_encoded*S_DEST_WIDTH +: S_DEST_WIDTH];
+            cmd_tuser_mux = int_cmd_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
+            cmd_valid_mux = int_cmd_valid[grant_encoded*M_COUNT+n] && grant_valid;
+        end
+
         assign int_cmd_ready[n*S_COUNT +: S_COUNT] = (grant_valid && cmd_ready_mux) << grant_encoded;
 
         for (m = 0; m < S_COUNT; m = m + 1) begin
@@ -846,18 +881,18 @@ generate
         reg [CMD_ADDR_WIDTH-1:0] id_reg = 0, id_next;
 
         reg [KEEP_WIDTH-1:0] last_cycle_tkeep_reg = {KEEP_WIDTH{1'b0}}, last_cycle_tkeep_next;
-        reg [ID_WIDTH-1:0] tid_reg = {ID_WIDTH{1'b0}}, tid_next;
-        reg [DEST_WIDTH-1:0] tdest_reg = {DEST_WIDTH{1'b0}}, tdest_next;
+        reg [M_ID_WIDTH-1:0] tid_reg = {M_ID_WIDTH{1'b0}}, tid_next;
+        reg [M_DEST_WIDTH-1:0] tdest_reg = {M_DEST_WIDTH_INT{1'b0}}, tdest_next;
         reg [USER_WIDTH-1:0] tuser_reg = {USER_WIDTH{1'b0}}, tuser_next;
 
-        reg [DATA_WIDTH-1:0] out_axis_tdata_reg = {DATA_WIDTH{1'b0}}, out_axis_tdata_next;
-        reg [KEEP_WIDTH-1:0] out_axis_tkeep_reg = {KEEP_WIDTH{1'b0}}, out_axis_tkeep_next;
-        reg                  out_axis_tvalid_reg = 1'b0, out_axis_tvalid_next;
-        wire                 out_axis_tready;
-        reg                  out_axis_tlast_reg = 1'b0, out_axis_tlast_next;
-        reg [ID_WIDTH-1:0]   out_axis_tid_reg   = {ID_WIDTH{1'b0}}, out_axis_tid_next;
-        reg [DEST_WIDTH-1:0] out_axis_tdest_reg = {DEST_WIDTH{1'b0}}, out_axis_tdest_next;
-        reg [USER_WIDTH-1:0] out_axis_tuser_reg = {USER_WIDTH{1'b0}}, out_axis_tuser_next;
+        reg [DATA_WIDTH-1:0]    out_axis_tdata_reg = {DATA_WIDTH{1'b0}}, out_axis_tdata_next;
+        reg [KEEP_WIDTH-1:0]    out_axis_tkeep_reg = {KEEP_WIDTH{1'b0}}, out_axis_tkeep_next;
+        reg                     out_axis_tvalid_reg = 1'b0, out_axis_tvalid_next;
+        wire                    out_axis_tready;
+        reg                     out_axis_tlast_reg = 1'b0, out_axis_tlast_next;
+        reg [M_ID_WIDTH-1:0]    out_axis_tid_reg   = {M_ID_WIDTH{1'b0}}, out_axis_tid_next;
+        reg [M_DEST_WIDTH-1:0]  out_axis_tdest_reg = {M_DEST_WIDTH_INT{1'b0}}, out_axis_tdest_next;
+        reg [USER_WIDTH-1:0]    out_axis_tuser_reg = {USER_WIDTH{1'b0}}, out_axis_tuser_next;
 
         reg  [RAM_ADDR_WIDTH-1:0] ram_rd_addr_reg = {RAM_ADDR_WIDTH{1'b0}}, ram_rd_addr_next;
         reg                       ram_rd_en_reg = 1'b0, ram_rd_en_next;
@@ -878,8 +913,8 @@ generate
         reg [DATA_WIDTH-1:0] out_fifo_tdata[31:0];
         reg [KEEP_WIDTH-1:0] out_fifo_tkeep[31:0];
         reg out_fifo_tlast[31:0];
-        reg [ID_WIDTH-1:0] out_fifo_tid[31:0];
-        reg [DEST_WIDTH-1:0] out_fifo_tdest[31:0];
+        reg [M_ID_WIDTH-1:0] out_fifo_tid[31:0];
+        reg [M_DEST_WIDTH-1:0] out_fifo_tdest[31:0];
         reg [USER_WIDTH-1:0] out_fifo_tuser[31:0];
 
         reg [5:0] out_fifo_data_wr_ptr_reg = 0;
@@ -888,8 +923,8 @@ generate
         reg [5:0] out_fifo_ctrl_wr_ptr_reg = 0;
         reg [KEEP_WIDTH-1:0] out_fifo_ctrl_wr_tkeep;
         reg out_fifo_ctrl_wr_tlast;
-        reg [ID_WIDTH-1:0] out_fifo_ctrl_wr_tid;
-        reg [DEST_WIDTH-1:0] out_fifo_ctrl_wr_tdest;
+        reg [M_ID_WIDTH-1:0] out_fifo_ctrl_wr_tid;
+        reg [M_DEST_WIDTH-1:0] out_fifo_ctrl_wr_tdest;
         reg [USER_WIDTH-1:0] out_fifo_ctrl_wr_tuser;
         reg out_fifo_ctrl_wr_en;
         reg [5:0] out_fifo_rd_ptr_reg = 0;
@@ -1071,9 +1106,9 @@ generate
             .M_KEEP_ENABLE(M_KEEP_ENABLE),
             .M_KEEP_WIDTH(M_KEEP_WIDTH),
             .ID_ENABLE(ID_ENABLE),
-            .ID_WIDTH(ID_WIDTH),
-            .DEST_ENABLE(1),
-            .DEST_WIDTH(DEST_WIDTH),
+            .ID_WIDTH(M_ID_WIDTH),
+            .DEST_ENABLE(M_DEST_WIDTH > 0),
+            .DEST_WIDTH(M_DEST_WIDTH_INT),
             .USER_ENABLE(USER_ENABLE),
             .USER_WIDTH(USER_WIDTH)
         )
@@ -1095,8 +1130,8 @@ generate
             .m_axis_tvalid(m_axis_tvalid[n]),
             .m_axis_tready(m_axis_tready[n]),
             .m_axis_tlast(m_axis_tlast[n]),
-            .m_axis_tid(m_axis_tid[ID_WIDTH*n +: ID_WIDTH]),
-            .m_axis_tdest(m_axis_tdest[DEST_WIDTH*n +: DEST_WIDTH]),
+            .m_axis_tid(m_axis_tid[M_ID_WIDTH*n +: M_ID_WIDTH]),
+            .m_axis_tdest(m_axis_tdest[M_DEST_WIDTH*n +: M_DEST_WIDTH_INT]),
             .m_axis_tuser(m_axis_tuser[USER_WIDTH*n +: USER_WIDTH])
         );
     end // m_ifaces
diff --git a/lib/axis/rtl/axis_ram_switch_wrap.py b/lib/axis/rtl/axis_ram_switch_wrap.py
index 6e07113f41f42ac8d07f954232ec31655324f6d7..45adfed5bf389e29f435c06185dd2f109953f5a1 100755
--- a/lib/axis/rtl/axis_ram_switch_wrap.py
+++ b/lib/axis/rtl/axis_ram_switch_wrap.py
@@ -100,11 +100,15 @@ module {{name}} #
     parameter M_KEEP_WIDTH = (M_DATA_WIDTH/8),
     // Propagate tid signal
     parameter ID_ENABLE = 0,
-    // tid signal width
-    parameter ID_WIDTH = 8,
-    // tdest signal width
+    // input tid signal width
+    parameter S_ID_WIDTH = 8,
+    // output tid signal width
+    parameter M_ID_WIDTH = S_ID_WIDTH+{{cm}},
+    // output tdest signal width
+    parameter M_DEST_WIDTH = 1,
+    // input tdest signal width
     // must be wide enough to uniquely address outputs
-    parameter DEST_WIDTH = {{cn}},
+    parameter S_DEST_WIDTH = M_DEST_WIDTH+{{cn}},
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
@@ -128,6 +132,8 @@ module {{name}} #
     // Interface connection control
     parameter M{{'%02d'%p}}_CONNECT = {{m}}'b{% for p in range(m) %}1{% endfor %},
 {%- endfor %}
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // select round robin arbitration
     parameter ARB_TYPE_ROUND_ROBIN = 1,
     // LSB priority selection
@@ -143,27 +149,27 @@ module {{name}} #
      * AXI Stream inputs
      */
 {%- for p in range(m) %}
-    input  wire [S_DATA_WIDTH-1:0] s{{'%02d'%p}}_axis_tdata,
-    input  wire [S_KEEP_WIDTH-1:0] s{{'%02d'%p}}_axis_tkeep,
-    input  wire                    s{{'%02d'%p}}_axis_tvalid,
-    output wire                    s{{'%02d'%p}}_axis_tready,
-    input  wire                    s{{'%02d'%p}}_axis_tlast,
-    input  wire [ID_WIDTH-1:0]     s{{'%02d'%p}}_axis_tid,
-    input  wire [DEST_WIDTH-1:0]   s{{'%02d'%p}}_axis_tdest,
-    input  wire [USER_WIDTH-1:0]   s{{'%02d'%p}}_axis_tuser,
+    input  wire [S_DATA_WIDTH-1:0]  s{{'%02d'%p}}_axis_tdata,
+    input  wire [S_KEEP_WIDTH-1:0]  s{{'%02d'%p}}_axis_tkeep,
+    input  wire                     s{{'%02d'%p}}_axis_tvalid,
+    output wire                     s{{'%02d'%p}}_axis_tready,
+    input  wire                     s{{'%02d'%p}}_axis_tlast,
+    input  wire [S_ID_WIDTH-1:0]    s{{'%02d'%p}}_axis_tid,
+    input  wire [S_DEST_WIDTH-1:0]  s{{'%02d'%p}}_axis_tdest,
+    input  wire [USER_WIDTH-1:0]    s{{'%02d'%p}}_axis_tuser,
 {% endfor %}
     /*
      * AXI Stream outputs
      */
 {%- for p in range(n) %}
-    output wire [M_DATA_WIDTH-1:0] m{{'%02d'%p}}_axis_tdata,
-    output wire [M_KEEP_WIDTH-1:0] m{{'%02d'%p}}_axis_tkeep,
-    output wire                    m{{'%02d'%p}}_axis_tvalid,
-    input  wire                    m{{'%02d'%p}}_axis_tready,
-    output wire                    m{{'%02d'%p}}_axis_tlast,
-    output wire [ID_WIDTH-1:0]     m{{'%02d'%p}}_axis_tid,
-    output wire [DEST_WIDTH-1:0]   m{{'%02d'%p}}_axis_tdest,
-    output wire [USER_WIDTH-1:0]   m{{'%02d'%p}}_axis_tuser,
+    output wire [M_DATA_WIDTH-1:0]  m{{'%02d'%p}}_axis_tdata,
+    output wire [M_KEEP_WIDTH-1:0]  m{{'%02d'%p}}_axis_tkeep,
+    output wire                     m{{'%02d'%p}}_axis_tvalid,
+    input  wire                     m{{'%02d'%p}}_axis_tready,
+    output wire                     m{{'%02d'%p}}_axis_tlast,
+    output wire [M_ID_WIDTH-1:0]    m{{'%02d'%p}}_axis_tid,
+    output wire [M_DEST_WIDTH-1:0]  m{{'%02d'%p}}_axis_tdest,
+    output wire [USER_WIDTH-1:0]    m{{'%02d'%p}}_axis_tuser,
 {% endfor %}
     /*
      * Status
@@ -174,7 +180,7 @@ module {{name}} #
 );
 
 // parameter sizing helpers
-function [DEST_WIDTH-1:0] w_dw(input [DEST_WIDTH-1:0] val);
+function [S_DEST_WIDTH-1:0] w_dw(input [S_DEST_WIDTH-1:0] val);
     w_dw = val;
 endfunction
 
@@ -195,8 +201,10 @@ axis_ram_switch #(
     .M_KEEP_ENABLE(M_KEEP_ENABLE),
     .M_KEEP_WIDTH(M_KEEP_WIDTH),
     .ID_ENABLE(ID_ENABLE),
-    .ID_WIDTH(ID_WIDTH),
-    .DEST_WIDTH(DEST_WIDTH),
+    .S_ID_WIDTH(S_ID_WIDTH),
+    .M_ID_WIDTH(M_ID_WIDTH),
+    .S_DEST_WIDTH(S_DEST_WIDTH),
+    .M_DEST_WIDTH(M_DEST_WIDTH),
     .USER_ENABLE(USER_ENABLE),
     .USER_WIDTH(USER_WIDTH),
     .USER_BAD_FRAME_VALUE(USER_BAD_FRAME_VALUE),
@@ -206,6 +214,7 @@ axis_ram_switch #(
     .M_BASE({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_BASE){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_TOP({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_TOP){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_CONNECT({ {% for p in range(n-1,-1,-1) %}w_s(M{{'%02d'%p}}_CONNECT){% if not loop.last %}, {% endif %}{% endfor %} }),
+    .UPDATE_TID(UPDATE_TID),
     .ARB_TYPE_ROUND_ROBIN(ARB_TYPE_ROUND_ROBIN),
     .ARB_LSB_HIGH_PRIORITY(ARB_LSB_HIGH_PRIORITY),
     .RAM_PIPELINE(RAM_PIPELINE)
diff --git a/lib/axis/rtl/axis_switch.v b/lib/axis/rtl/axis_switch.v
index 4688d639ffa79d21348acedccb180d49ad5c2bbe..c949b700b618e60744512075b01042628bb026e4 100644
--- a/lib/axis/rtl/axis_switch.v
+++ b/lib/axis/rtl/axis_switch.v
@@ -45,28 +45,34 @@ module axis_switch #
     parameter KEEP_WIDTH = (DATA_WIDTH/8),
     // Propagate tid signal
     parameter ID_ENABLE = 0,
-    // tid signal width
-    parameter ID_WIDTH = 8,
-    // tdest signal width
+    // input tid signal width
+    parameter S_ID_WIDTH = 8,
+    // output tid signal width
+    parameter M_ID_WIDTH = S_ID_WIDTH+$clog2(S_COUNT),
+    // output tdest signal width
+    parameter M_DEST_WIDTH = 1,
+    // input tdest signal width
     // must be wide enough to uniquely address outputs
-    parameter DEST_WIDTH = $clog2(M_COUNT),
+    parameter S_DEST_WIDTH = M_DEST_WIDTH+$clog2(M_COUNT),
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
     parameter USER_WIDTH = 1,
     // Output interface routing base tdest selection
-    // Concatenate M_COUNT DEST_WIDTH sized constants
+    // Concatenate M_COUNT S_DEST_WIDTH sized constants
     // Port selected if M_BASE <= tdest <= M_TOP
     // set to zero for default routing with tdest MSBs as port index
     parameter M_BASE = 0,
     // Output interface routing top tdest selection
-    // Concatenate M_COUNT DEST_WIDTH sized constants
+    // Concatenate M_COUNT S_DEST_WIDTH sized constants
     // Port selected if M_BASE <= tdest <= M_TOP
     // set to zero to inherit from M_BASE
     parameter M_TOP = 0,
     // Interface connection control
     // M_COUNT concatenated fields of S_COUNT bits
     parameter M_CONNECT = {M_COUNT{{S_COUNT{1'b1}}}},
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // Input interface register type
     // 0 to bypass, 1 for simple buffer, 2 for skid buffer
     parameter S_REG_TYPE = 0,
@@ -79,65 +85,80 @@ module axis_switch #
     parameter ARB_LSB_HIGH_PRIORITY = 1
 )
 (
-    input  wire                          clk,
-    input  wire                          rst,
+    input  wire                             clk,
+    input  wire                             rst,
 
     /*
      * AXI Stream inputs
      */
-    input  wire [S_COUNT*DATA_WIDTH-1:0] s_axis_tdata,
-    input  wire [S_COUNT*KEEP_WIDTH-1:0] s_axis_tkeep,
-    input  wire [S_COUNT-1:0]            s_axis_tvalid,
-    output wire [S_COUNT-1:0]            s_axis_tready,
-    input  wire [S_COUNT-1:0]            s_axis_tlast,
-    input  wire [S_COUNT*ID_WIDTH-1:0]   s_axis_tid,
-    input  wire [S_COUNT*DEST_WIDTH-1:0] s_axis_tdest,
-    input  wire [S_COUNT*USER_WIDTH-1:0] s_axis_tuser,
+    input  wire [S_COUNT*DATA_WIDTH-1:0]    s_axis_tdata,
+    input  wire [S_COUNT*KEEP_WIDTH-1:0]    s_axis_tkeep,
+    input  wire [S_COUNT-1:0]               s_axis_tvalid,
+    output wire [S_COUNT-1:0]               s_axis_tready,
+    input  wire [S_COUNT-1:0]               s_axis_tlast,
+    input  wire [S_COUNT*S_ID_WIDTH-1:0]    s_axis_tid,
+    input  wire [S_COUNT*S_DEST_WIDTH-1:0]  s_axis_tdest,
+    input  wire [S_COUNT*USER_WIDTH-1:0]    s_axis_tuser,
 
     /*
      * AXI Stream outputs
      */
-    output wire [M_COUNT*DATA_WIDTH-1:0] m_axis_tdata,
-    output wire [M_COUNT*KEEP_WIDTH-1:0] m_axis_tkeep,
-    output wire [M_COUNT-1:0]            m_axis_tvalid,
-    input  wire [M_COUNT-1:0]            m_axis_tready,
-    output wire [M_COUNT-1:0]            m_axis_tlast,
-    output wire [M_COUNT*ID_WIDTH-1:0]   m_axis_tid,
-    output wire [M_COUNT*DEST_WIDTH-1:0] m_axis_tdest,
-    output wire [M_COUNT*USER_WIDTH-1:0] m_axis_tuser
+    output wire [M_COUNT*DATA_WIDTH-1:0]    m_axis_tdata,
+    output wire [M_COUNT*KEEP_WIDTH-1:0]    m_axis_tkeep,
+    output wire [M_COUNT-1:0]               m_axis_tvalid,
+    input  wire [M_COUNT-1:0]               m_axis_tready,
+    output wire [M_COUNT-1:0]               m_axis_tlast,
+    output wire [M_COUNT*M_ID_WIDTH-1:0]    m_axis_tid,
+    output wire [M_COUNT*M_DEST_WIDTH-1:0]  m_axis_tdest,
+    output wire [M_COUNT*USER_WIDTH-1:0]    m_axis_tuser
 );
 
 parameter CL_S_COUNT = $clog2(S_COUNT);
 parameter CL_M_COUNT = $clog2(M_COUNT);
 
+parameter S_ID_WIDTH_INT = S_ID_WIDTH > 0 ? S_ID_WIDTH : 1;
+parameter M_DEST_WIDTH_INT = M_DEST_WIDTH > 0 ? M_DEST_WIDTH : 1;
+
 integer i, j;
 
 // check configuration
 initial begin
-    if (DEST_WIDTH < CL_M_COUNT) begin
-        $error("Error: DEST_WIDTH too small for port count (instance %m)");
+    if (S_DEST_WIDTH < CL_M_COUNT) begin
+        $error("Error: S_DEST_WIDTH too small for port count (instance %m)");
         $finish;
     end
 
+    if (UPDATE_TID) begin
+        if (!ID_ENABLE) begin
+            $error("Error: UPDATE_TID set requires ID_ENABLE set (instance %m)");
+            $finish;
+        end
+
+        if (M_ID_WIDTH < CL_S_COUNT) begin
+            $error("Error: M_ID_WIDTH too small for port count (instance %m)");
+            $finish;
+        end
+    end
+
     if (M_BASE == 0) begin
         // M_BASE is zero, route with tdest as port index
         $display("Addressing configuration for axis_switch instance %m");
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            $display("%d: %08x-%08x (connect mask %b)", i, i << (DEST_WIDTH-CL_M_COUNT), ((i+1) << (DEST_WIDTH-CL_M_COUNT))-1, M_CONNECT[i*S_COUNT +: S_COUNT]);
+            $display("%d: %08x-%08x (connect mask %b)", i, i << (S_DEST_WIDTH-CL_M_COUNT), ((i+1) << (S_DEST_WIDTH-CL_M_COUNT))-1, M_CONNECT[i*S_COUNT +: S_COUNT]);
         end
 
     end else if (M_TOP == 0) begin
         // M_TOP is zero, assume equal to M_BASE
         $display("Addressing configuration for axis_switch instance %m");
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            $display("%d: %08x (connect mask %b)", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
+            $display("%d: %08x (connect mask %b)", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
         end
 
         for (i = 0; i < M_COUNT; i = i + 1) begin
             for (j = i+1; j < M_COUNT; j = j + 1) begin
-                if (M_BASE[i*DEST_WIDTH +: DEST_WIDTH] == M_BASE[j*DEST_WIDTH +: DEST_WIDTH]) begin
-                    $display("%d: %08x", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH]);
-                    $display("%d: %08x", j, M_BASE[j*DEST_WIDTH +: DEST_WIDTH]);
+                if (M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH] == M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH]) begin
+                    $display("%d: %08x", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH]);
+                    $display("%d: %08x", j, M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH]);
                     $error("Error: ranges overlap (instance %m)");
                     $finish;
                 end
@@ -146,11 +167,11 @@ initial begin
     end else begin
         $display("Addressing configuration for axis_switch instance %m");
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            $display("%d: %08x-%08x (connect mask %b)", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH], M_TOP[i*DEST_WIDTH +: DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
+            $display("%d: %08x-%08x (connect mask %b)", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_CONNECT[i*S_COUNT +: S_COUNT]);
         end
 
         for (i = 0; i < M_COUNT; i = i + 1) begin
-            if (M_BASE[i*DEST_WIDTH +: DEST_WIDTH] > M_TOP[i*DEST_WIDTH +: DEST_WIDTH]) begin
+            if (M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH] > M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH]) begin
                 $error("Error: invalid range (instance %m)");
                 $finish;
             end
@@ -158,9 +179,9 @@ initial begin
 
         for (i = 0; i < M_COUNT; i = i + 1) begin
             for (j = i+1; j < M_COUNT; j = j + 1) begin
-                if (M_BASE[i*DEST_WIDTH +: DEST_WIDTH] <= M_TOP[j*DEST_WIDTH +: DEST_WIDTH] && M_BASE[j*DEST_WIDTH +: DEST_WIDTH] <= M_TOP[i*DEST_WIDTH +: DEST_WIDTH]) begin
-                    $display("%d: %08x-%08x", i, M_BASE[i*DEST_WIDTH +: DEST_WIDTH], M_TOP[i*DEST_WIDTH +: DEST_WIDTH]);
-                    $display("%d: %08x-%08x", j, M_BASE[j*DEST_WIDTH +: DEST_WIDTH], M_TOP[j*DEST_WIDTH +: DEST_WIDTH]);
+                if (M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH] <= M_TOP[j*S_DEST_WIDTH +: S_DEST_WIDTH] && M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH] <= M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH]) begin
+                    $display("%d: %08x-%08x", i, M_BASE[i*S_DEST_WIDTH +: S_DEST_WIDTH], M_TOP[i*S_DEST_WIDTH +: S_DEST_WIDTH]);
+                    $display("%d: %08x-%08x", j, M_BASE[j*S_DEST_WIDTH +: S_DEST_WIDTH], M_TOP[j*S_DEST_WIDTH +: S_DEST_WIDTH]);
                     $error("Error: ranges overlap (instance %m)");
                     $finish;
                 end
@@ -169,17 +190,17 @@ initial begin
     end
 end
 
-wire [S_COUNT*DATA_WIDTH-1:0] int_s_axis_tdata;
-wire [S_COUNT*KEEP_WIDTH-1:0] int_s_axis_tkeep;
-wire [S_COUNT-1:0]            int_s_axis_tvalid;
-wire [S_COUNT-1:0]            int_s_axis_tready;
-wire [S_COUNT-1:0]            int_s_axis_tlast;
-wire [S_COUNT*ID_WIDTH-1:0]   int_s_axis_tid;
-wire [S_COUNT*DEST_WIDTH-1:0] int_s_axis_tdest;
-wire [S_COUNT*USER_WIDTH-1:0] int_s_axis_tuser;
+wire [S_COUNT*DATA_WIDTH-1:0]    int_s_axis_tdata;
+wire [S_COUNT*KEEP_WIDTH-1:0]    int_s_axis_tkeep;
+wire [S_COUNT-1:0]               int_s_axis_tvalid;
+wire [S_COUNT-1:0]               int_s_axis_tready;
+wire [S_COUNT-1:0]               int_s_axis_tlast;
+wire [S_COUNT*S_ID_WIDTH-1:0]    int_s_axis_tid;
+wire [S_COUNT*S_DEST_WIDTH-1:0]  int_s_axis_tdest;
+wire [S_COUNT*USER_WIDTH-1:0]    int_s_axis_tuser;
 
-wire [S_COUNT*M_COUNT-1:0]    int_axis_tvalid;
-wire [M_COUNT*S_COUNT-1:0]    int_axis_tready;
+wire [S_COUNT*M_COUNT-1:0]       int_axis_tvalid;
+wire [M_COUNT*S_COUNT-1:0]       int_axis_tready;
 
 generate
 
@@ -212,7 +233,7 @@ generate
                             drop_next = 1'b0;
                         end else begin
                             // M_BASE is zero, route with $clog2(M_COUNT) MSBs of tdest as port index
-                            if (int_s_axis_tdest[m*DEST_WIDTH+(DEST_WIDTH-CL_M_COUNT) +: CL_M_COUNT] == k && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
+                            if (int_s_axis_tdest[m*S_DEST_WIDTH+(S_DEST_WIDTH-CL_M_COUNT) +: CL_M_COUNT] == k && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
                                 select_next = k;
                                 select_valid_next = 1'b1;
                                 drop_next = 1'b0;
@@ -220,13 +241,13 @@ generate
                         end
                     end else if (M_TOP == 0) begin
                         // M_TOP is zero, assume equal to M_BASE
-                        if (int_s_axis_tdest[m*DEST_WIDTH +: DEST_WIDTH] == M_BASE[k*DEST_WIDTH +: DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
+                        if (int_s_axis_tdest[m*S_DEST_WIDTH +: S_DEST_WIDTH] == M_BASE[k*S_DEST_WIDTH +: S_DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
                             select_next = k;
                             select_valid_next = 1'b1;
                             drop_next = 1'b0;
                         end
                     end else begin
-                        if (int_s_axis_tdest[m*DEST_WIDTH +: DEST_WIDTH] >= M_BASE[k*DEST_WIDTH +: DEST_WIDTH] && int_s_axis_tdest[m*DEST_WIDTH +: DEST_WIDTH] <= M_TOP[k*DEST_WIDTH +: DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
+                        if (int_s_axis_tdest[m*S_DEST_WIDTH +: S_DEST_WIDTH] >= M_BASE[k*S_DEST_WIDTH +: S_DEST_WIDTH] && int_s_axis_tdest[m*S_DEST_WIDTH +: S_DEST_WIDTH] <= M_TOP[k*S_DEST_WIDTH +: S_DEST_WIDTH] && (M_CONNECT & (1 << (m+k*S_COUNT)))) begin
                             select_next = k;
                             select_valid_next = 1'b1;
                             drop_next = 1'b0;
@@ -257,10 +278,10 @@ generate
             .KEEP_ENABLE(KEEP_ENABLE),
             .KEEP_WIDTH(KEEP_WIDTH),
             .LAST_ENABLE(1),
-            .ID_ENABLE(ID_ENABLE),
-            .ID_WIDTH(ID_WIDTH),
+            .ID_ENABLE(ID_ENABLE && S_ID_WIDTH > 0),
+            .ID_WIDTH(S_ID_WIDTH_INT),
             .DEST_ENABLE(1),
-            .DEST_WIDTH(DEST_WIDTH),
+            .DEST_WIDTH(S_DEST_WIDTH),
             .USER_ENABLE(USER_ENABLE),
             .USER_WIDTH(USER_WIDTH),
             .REG_TYPE(S_REG_TYPE)
@@ -274,8 +295,8 @@ generate
             .s_axis_tvalid(s_axis_tvalid[m]),
             .s_axis_tready(s_axis_tready[m]),
             .s_axis_tlast(s_axis_tlast[m]),
-            .s_axis_tid(s_axis_tid[m*ID_WIDTH +: ID_WIDTH]),
-            .s_axis_tdest(s_axis_tdest[m*DEST_WIDTH +: DEST_WIDTH]),
+            .s_axis_tid(s_axis_tid[m*S_ID_WIDTH +: S_ID_WIDTH_INT]),
+            .s_axis_tdest(s_axis_tdest[m*S_DEST_WIDTH +: S_DEST_WIDTH]),
             .s_axis_tuser(s_axis_tuser[m*USER_WIDTH +: USER_WIDTH]),
             // AXI output
             .m_axis_tdata(int_s_axis_tdata[m*DATA_WIDTH +: DATA_WIDTH]),
@@ -283,8 +304,8 @@ generate
             .m_axis_tvalid(int_s_axis_tvalid[m]),
             .m_axis_tready(int_s_axis_tready[m]),
             .m_axis_tlast(int_s_axis_tlast[m]),
-            .m_axis_tid(int_s_axis_tid[m*ID_WIDTH +: ID_WIDTH]),
-            .m_axis_tdest(int_s_axis_tdest[m*DEST_WIDTH +: DEST_WIDTH]),
+            .m_axis_tid(int_s_axis_tid[m*S_ID_WIDTH +: S_ID_WIDTH_INT]),
+            .m_axis_tdest(int_s_axis_tdest[m*S_DEST_WIDTH +: S_DEST_WIDTH]),
             .m_axis_tuser(int_s_axis_tuser[m*USER_WIDTH +: USER_WIDTH])
         );
     end // s_ifaces
@@ -316,20 +337,33 @@ generate
         );
 
         // mux
-        wire [DATA_WIDTH-1:0] s_axis_tdata_mux   = int_s_axis_tdata[grant_encoded*DATA_WIDTH +: DATA_WIDTH];
-        wire [KEEP_WIDTH-1:0] s_axis_tkeep_mux   = int_s_axis_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
-        wire                  s_axis_tvalid_mux  = int_axis_tvalid[grant_encoded*M_COUNT+n] && grant_valid;
-        wire                  s_axis_tready_mux;
-        wire                  s_axis_tlast_mux   = int_s_axis_tlast[grant_encoded];
-        wire [ID_WIDTH-1:0]   s_axis_tid_mux     = int_s_axis_tid[grant_encoded*ID_WIDTH +: ID_WIDTH];
-        wire [DEST_WIDTH-1:0] s_axis_tdest_mux   = int_s_axis_tdest[grant_encoded*DEST_WIDTH +: DEST_WIDTH];
-        wire [USER_WIDTH-1:0] s_axis_tuser_mux   = int_s_axis_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
+        reg  [DATA_WIDTH-1:0]    m_axis_tdata_mux;
+        reg  [KEEP_WIDTH-1:0]    m_axis_tkeep_mux;
+        reg                      m_axis_tvalid_mux;
+        wire                     m_axis_tready_mux;
+        reg                      m_axis_tlast_mux;
+        reg  [M_ID_WIDTH-1:0]    m_axis_tid_mux;
+        reg  [M_DEST_WIDTH-1:0]  m_axis_tdest_mux;
+        reg  [USER_WIDTH-1:0]    m_axis_tuser_mux;
+
+        always @* begin
+            m_axis_tdata_mux   = int_s_axis_tdata[grant_encoded*DATA_WIDTH +: DATA_WIDTH];
+            m_axis_tkeep_mux   = int_s_axis_tkeep[grant_encoded*KEEP_WIDTH +: KEEP_WIDTH];
+            m_axis_tvalid_mux  = int_axis_tvalid[grant_encoded*M_COUNT+n] && grant_valid;
+            m_axis_tlast_mux   = int_s_axis_tlast[grant_encoded];
+            m_axis_tid_mux     = int_s_axis_tid[grant_encoded*S_ID_WIDTH +: S_ID_WIDTH_INT];
+            if (UPDATE_TID && S_COUNT > 1) begin
+                m_axis_tid_mux[M_ID_WIDTH-1:M_ID_WIDTH-CL_S_COUNT] = grant_encoded;
+            end
+            m_axis_tdest_mux   = int_s_axis_tdest[grant_encoded*S_DEST_WIDTH +: S_DEST_WIDTH];
+            m_axis_tuser_mux   = int_s_axis_tuser[grant_encoded*USER_WIDTH +: USER_WIDTH];
+        end
 
-        assign int_axis_tready[n*S_COUNT +: S_COUNT] = (grant_valid && s_axis_tready_mux) << grant_encoded;
+        assign int_axis_tready[n*S_COUNT +: S_COUNT] = (grant_valid && m_axis_tready_mux) << grant_encoded;
 
         for (m = 0; m < S_COUNT; m = m + 1) begin
             assign request[m] = int_axis_tvalid[m*M_COUNT+n] && !grant[m];
-            assign acknowledge[m] = grant[m] && int_axis_tvalid[m*M_COUNT+n] && s_axis_tlast_mux && s_axis_tready_mux;
+            assign acknowledge[m] = grant[m] && int_axis_tvalid[m*M_COUNT+n] && m_axis_tlast_mux && m_axis_tready_mux;
         end
 
         // M side register
@@ -339,9 +373,9 @@ generate
             .KEEP_WIDTH(KEEP_WIDTH),
             .LAST_ENABLE(1),
             .ID_ENABLE(ID_ENABLE),
-            .ID_WIDTH(ID_WIDTH),
-            .DEST_ENABLE(1),
-            .DEST_WIDTH(DEST_WIDTH),
+            .ID_WIDTH(M_ID_WIDTH),
+            .DEST_ENABLE(M_DEST_WIDTH > 0),
+            .DEST_WIDTH(M_DEST_WIDTH_INT),
             .USER_ENABLE(USER_ENABLE),
             .USER_WIDTH(USER_WIDTH),
             .REG_TYPE(M_REG_TYPE)
@@ -350,22 +384,22 @@ generate
             .clk(clk),
             .rst(rst),
             // AXI input
-            .s_axis_tdata(s_axis_tdata_mux),
-            .s_axis_tkeep(s_axis_tkeep_mux),
-            .s_axis_tvalid(s_axis_tvalid_mux),
-            .s_axis_tready(s_axis_tready_mux),
-            .s_axis_tlast(s_axis_tlast_mux),
-            .s_axis_tid(s_axis_tid_mux),
-            .s_axis_tdest(s_axis_tdest_mux),
-            .s_axis_tuser(s_axis_tuser_mux),
+            .s_axis_tdata(m_axis_tdata_mux),
+            .s_axis_tkeep(m_axis_tkeep_mux),
+            .s_axis_tvalid(m_axis_tvalid_mux),
+            .s_axis_tready(m_axis_tready_mux),
+            .s_axis_tlast(m_axis_tlast_mux),
+            .s_axis_tid(m_axis_tid_mux),
+            .s_axis_tdest(m_axis_tdest_mux),
+            .s_axis_tuser(m_axis_tuser_mux),
             // AXI output
             .m_axis_tdata(m_axis_tdata[n*DATA_WIDTH +: DATA_WIDTH]),
             .m_axis_tkeep(m_axis_tkeep[n*KEEP_WIDTH +: KEEP_WIDTH]),
             .m_axis_tvalid(m_axis_tvalid[n]),
             .m_axis_tready(m_axis_tready[n]),
             .m_axis_tlast(m_axis_tlast[n]),
-            .m_axis_tid(m_axis_tid[n*ID_WIDTH +: ID_WIDTH]),
-            .m_axis_tdest(m_axis_tdest[n*DEST_WIDTH +: DEST_WIDTH]),
+            .m_axis_tid(m_axis_tid[n*M_ID_WIDTH +: M_ID_WIDTH]),
+            .m_axis_tdest(m_axis_tdest[n*M_DEST_WIDTH +: M_DEST_WIDTH_INT]),
             .m_axis_tuser(m_axis_tuser[n*USER_WIDTH +: USER_WIDTH])
         );
     end // m_ifaces
diff --git a/lib/axis/rtl/axis_switch_wrap.py b/lib/axis/rtl/axis_switch_wrap.py
index d7d3bbfc6620c625acae55b8f50f2b7594a8130c..0af37e02acbfd44393c59679191b1fa78383afd4 100755
--- a/lib/axis/rtl/axis_switch_wrap.py
+++ b/lib/axis/rtl/axis_switch_wrap.py
@@ -84,11 +84,15 @@ module {{name}} #
     parameter KEEP_WIDTH = (DATA_WIDTH/8),
     // Propagate tid signal
     parameter ID_ENABLE = 0,
-    // tid signal width
-    parameter ID_WIDTH = 8,
-    // tdest signal width
+    // input tid signal width
+    parameter S_ID_WIDTH = 8,
+    // output tid signal width
+    parameter M_ID_WIDTH = S_ID_WIDTH+{{cm}},
+    // output tdest signal width
+    parameter M_DEST_WIDTH = 1,
+    // input tdest signal width
     // must be wide enough to uniquely address outputs
-    parameter DEST_WIDTH = {{cn}},
+    parameter S_DEST_WIDTH = M_DEST_WIDTH+{{cn}},
     // Propagate tuser signal
     parameter USER_ENABLE = 1,
     // tuser signal width
@@ -103,6 +107,8 @@ module {{name}} #
     // Interface connection control
     parameter M{{'%02d'%p}}_CONNECT = {{m}}'b{% for p in range(m) %}1{% endfor %},
 {%- endfor %}
+    // Update tid with routing information
+    parameter UPDATE_TID = 0,
     // Input interface register type
     // 0 to bypass, 1 for simple buffer, 2 for skid buffer
     parameter S_REG_TYPE = 0,
@@ -122,32 +128,32 @@ module {{name}} #
      * AXI Stream inputs
      */
 {%- for p in range(m) %}
-    input  wire [DATA_WIDTH-1:0] s{{'%02d'%p}}_axis_tdata,
-    input  wire [KEEP_WIDTH-1:0] s{{'%02d'%p}}_axis_tkeep,
-    input  wire                  s{{'%02d'%p}}_axis_tvalid,
-    output wire                  s{{'%02d'%p}}_axis_tready,
-    input  wire                  s{{'%02d'%p}}_axis_tlast,
-    input  wire [ID_WIDTH-1:0]   s{{'%02d'%p}}_axis_tid,
-    input  wire [DEST_WIDTH-1:0] s{{'%02d'%p}}_axis_tdest,
-    input  wire [USER_WIDTH-1:0] s{{'%02d'%p}}_axis_tuser,
+    input  wire [DATA_WIDTH-1:0]    s{{'%02d'%p}}_axis_tdata,
+    input  wire [KEEP_WIDTH-1:0]    s{{'%02d'%p}}_axis_tkeep,
+    input  wire                     s{{'%02d'%p}}_axis_tvalid,
+    output wire                     s{{'%02d'%p}}_axis_tready,
+    input  wire                     s{{'%02d'%p}}_axis_tlast,
+    input  wire [S_ID_WIDTH-1:0]    s{{'%02d'%p}}_axis_tid,
+    input  wire [S_DEST_WIDTH-1:0]  s{{'%02d'%p}}_axis_tdest,
+    input  wire [USER_WIDTH-1:0]    s{{'%02d'%p}}_axis_tuser,
 {% endfor %}
     /*
      * AXI Stream outputs
      */
 {%- for p in range(n) %}
-    output wire [DATA_WIDTH-1:0] m{{'%02d'%p}}_axis_tdata,
-    output wire [KEEP_WIDTH-1:0] m{{'%02d'%p}}_axis_tkeep,
-    output wire                  m{{'%02d'%p}}_axis_tvalid,
-    input  wire                  m{{'%02d'%p}}_axis_tready,
-    output wire                  m{{'%02d'%p}}_axis_tlast,
-    output wire [ID_WIDTH-1:0]   m{{'%02d'%p}}_axis_tid,
-    output wire [DEST_WIDTH-1:0] m{{'%02d'%p}}_axis_tdest,
-    output wire [USER_WIDTH-1:0] m{{'%02d'%p}}_axis_tuser{% if not loop.last %},{% endif %}
+    output wire [DATA_WIDTH-1:0]    m{{'%02d'%p}}_axis_tdata,
+    output wire [KEEP_WIDTH-1:0]    m{{'%02d'%p}}_axis_tkeep,
+    output wire                     m{{'%02d'%p}}_axis_tvalid,
+    input  wire                     m{{'%02d'%p}}_axis_tready,
+    output wire                     m{{'%02d'%p}}_axis_tlast,
+    output wire [M_ID_WIDTH-1:0]    m{{'%02d'%p}}_axis_tid,
+    output wire [M_DEST_WIDTH-1:0]  m{{'%02d'%p}}_axis_tdest,
+    output wire [USER_WIDTH-1:0]    m{{'%02d'%p}}_axis_tuser{% if not loop.last %},{% endif %}
 {% endfor -%}
 );
 
 // parameter sizing helpers
-function [DEST_WIDTH-1:0] w_dw(input [DEST_WIDTH-1:0] val);
+function [S_DEST_WIDTH-1:0] w_dw(input [S_DEST_WIDTH-1:0] val);
     w_dw = val;
 endfunction
 
@@ -162,13 +168,16 @@ axis_switch #(
     .KEEP_ENABLE(KEEP_ENABLE),
     .KEEP_WIDTH(KEEP_WIDTH),
     .ID_ENABLE(ID_ENABLE),
-    .ID_WIDTH(ID_WIDTH),
-    .DEST_WIDTH(DEST_WIDTH),
+    .S_ID_WIDTH(S_ID_WIDTH),
+    .M_ID_WIDTH(M_ID_WIDTH),
+    .S_DEST_WIDTH(S_DEST_WIDTH),
+    .M_DEST_WIDTH(M_DEST_WIDTH),
     .USER_ENABLE(USER_ENABLE),
     .USER_WIDTH(USER_WIDTH),
     .M_BASE({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_BASE){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_TOP({ {% for p in range(n-1,-1,-1) %}w_dw(M{{'%02d'%p}}_TOP){% if not loop.last %}, {% endif %}{% endfor %} }),
     .M_CONNECT({ {% for p in range(n-1,-1,-1) %}w_s(M{{'%02d'%p}}_CONNECT){% if not loop.last %}, {% endif %}{% endfor %} }),
+    .UPDATE_TID(UPDATE_TID),
     .S_REG_TYPE(S_REG_TYPE),
     .M_REG_TYPE(M_REG_TYPE),
     .ARB_TYPE_ROUND_ROBIN(ARB_TYPE_ROUND_ROBIN),
diff --git a/lib/axis/tb/axis_adapter/test_axis_adapter.py b/lib/axis/tb/axis_adapter/test_axis_adapter.py
index 11a2268bf35f0d4964fda141e6042e10ab81e356..486ac879c450db6cc4b9cf706fcebf9d70899ce4 100644
--- a/lib/axis/tb/axis_adapter/test_axis_adapter.py
+++ b/lib/axis/tb/axis_adapter/test_axis_adapter.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_arb_mux/Makefile b/lib/axis/tb/axis_arb_mux/Makefile
index 3e3dc20a6feaa63e22faac2d8740ff5124568abf..962af950fff112fe5abc2a58d73bb66fd484e15a 100644
--- a/lib/axis/tb/axis_arb_mux/Makefile
+++ b/lib/axis/tb/axis_arb_mux/Makefile
@@ -26,10 +26,10 @@ WAVES ?= 0
 COCOTB_HDL_TIMEUNIT = 1ns
 COCOTB_HDL_TIMEPRECISION = 1ps
 
-export PARAM_PORTS ?= 4
+export PORTS ?= 4
 
 DUT      = axis_arb_mux
-WRAPPER  = $(DUT)_wrap_$(PARAM_PORTS)
+WRAPPER  = $(DUT)_wrap_$(PORTS)
 TOPLEVEL = $(WRAPPER)
 MODULE   = test_$(DUT)
 VERILOG_SOURCES += $(WRAPPER).v
@@ -42,12 +42,14 @@ export PARAM_DATA_WIDTH ?= 8
 export PARAM_KEEP_ENABLE ?= $(shell expr $(PARAM_DATA_WIDTH) \> 8 )
 export PARAM_KEEP_WIDTH ?= $(shell expr $(PARAM_DATA_WIDTH) / 8 )
 export PARAM_ID_ENABLE ?= 1
-export PARAM_ID_WIDTH ?= 8
+export PARAM_S_ID_WIDTH ?= 8
+export PARAM_M_ID_WIDTH ?= $(shell python -c "print($(PARAM_S_ID_WIDTH) + ($(PORTS)-1).bit_length())")
 export PARAM_DEST_ENABLE ?= 1
 export PARAM_DEST_WIDTH ?= 8
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
 export PARAM_LAST_ENABLE ?= 1
+export PARAM_UPDATE_TID ?= 1
 export PARAM_ARB_TYPE_ROUND_ROBIN ?= 0
 export PARAM_ARB_LSB_HIGH_PRIORITY ?= 1
 
@@ -58,12 +60,14 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).KEEP_ENABLE=$(PARAM_KEEP_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).KEEP_WIDTH=$(PARAM_KEEP_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).ID_ENABLE=$(PARAM_ID_ENABLE)
-	COMPILE_ARGS += -P $(TOPLEVEL).ID_WIDTH=$(PARAM_ID_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).M_ID_WIDTH=$(PARAM_M_ID_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).S_ID_WIDTH=$(PARAM_S_ID_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).DEST_ENABLE=$(PARAM_DEST_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).DEST_WIDTH=$(PARAM_DEST_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).LAST_ENABLE=$(PARAM_LAST_ENABLE)
+	COMPILE_ARGS += -P $(TOPLEVEL).UPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 
@@ -78,12 +82,14 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GKEEP_ENABLE=$(PARAM_KEEP_ENABLE)
 	COMPILE_ARGS += -GKEEP_WIDTH=$(PARAM_KEEP_WIDTH)
 	COMPILE_ARGS += -GID_ENABLE=$(PARAM_ID_ENABLE)
-	COMPILE_ARGS += -GID_WIDTH=$(PARAM_ID_WIDTH)
+	COMPILE_ARGS += -GM_ID_WIDTH=$(PARAM_M_ID_WIDTH)
+	COMPILE_ARGS += -GS_ID_WIDTH=$(PARAM_S_ID_WIDTH)
 	COMPILE_ARGS += -GDEST_ENABLE=$(PARAM_DEST_ENABLE)
 	COMPILE_ARGS += -GDEST_WIDTH=$(PARAM_DEST_WIDTH)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
 	COMPILE_ARGS += -GLAST_ENABLE=$(PARAM_LAST_ENABLE)
+	COMPILE_ARGS += -GUPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -GARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -GARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 
@@ -95,7 +101,7 @@ endif
 include $(shell cocotb-config --makefiles)/Makefile.sim
 
 $(WRAPPER).v: ../../rtl/$(DUT)_wrap.py
-	$< -p $(PARAM_PORTS)
+	$< -p $(PORTS)
 
 iverilog_dump.v:
 	echo 'module iverilog_dump();' > $@
diff --git a/lib/axis/tb/axis_arb_mux/test_axis_arb_mux.py b/lib/axis/tb/axis_arb_mux/test_axis_arb_mux.py
index 3483a6e1eae3ccbc2dcaff0b6ae4df333a48f099..dfe264b94ff432ced5bcfad6ba0f04ce10116cb1 100644
--- a/lib/axis/tb/axis_arb_mux/test_axis_arb_mux.py
+++ b/lib/axis/tb/axis_arb_mux/test_axis_arb_mux.py
@@ -67,10 +67,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -79,7 +79,15 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source[port].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -92,19 +100,21 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
-        test_frame.tid = cur_id
+        test_frame.tid = cur_id | (port << src_shift)
         test_frame.tdest = cur_id
 
         test_frames.append(test_frame)
         await tb.source[port].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for test_frame in test_frames:
         rx_frame = await tb.sink.recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == port
+        assert (rx_frame.tid >> id_width) == port
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -140,7 +150,15 @@ async def run_arb_test(dut):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -153,29 +171,34 @@ async def run_arb_test(dut):
 
     for k in range(5):
         test_frame = AxiStreamFrame(test_data, tx_complete=Event())
-        test_frame.tid = cur_id
+
+        src_ind = 0
 
         if k == 0:
-            test_frame.tdest = 0
+            src_ind = 0
         elif k == 4:
             await test_frames[1].tx_complete.wait()
             for j in range(8):
                 await RisingEdge(dut.clk)
-            test_frame.tdest = 0
+            src_ind = 0
         else:
-            test_frame.tdest = 1
+            src_ind = 1
+
+        test_frame.tid = cur_id | (src_ind << src_shift)
+        test_frame.tdest = 0
 
         test_frames.append(test_frame)
-        await tb.source[test_frame.tdest].send(test_frame)
+        await tb.source[src_ind].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for k in [0, 1, 2, 4, 3]:
         test_frame = test_frames[k]
         rx_frame = await tb.sink.recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -190,7 +213,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -206,13 +237,13 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
             length = random.randint(1, byte_lanes*16)
             test_data = bytearray(itertools.islice(itertools.cycle(range(256)), length))
             test_frame = AxiStreamFrame(test_data)
-            test_frame.tid = p
+            test_frame.tid = cur_id | (p << src_shift)
             test_frame.tdest = cur_id
 
             test_frames[p].append(test_frame)
             await tb.source[p].send(test_frame)
 
-            cur_id = (cur_id + 1) % id_count
+            cur_id = (cur_id + 1) % max_count
 
     while any(test_frames):
         rx_frame = await tb.sink.recv()
@@ -220,14 +251,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
         test_frame = None
 
         for lst in test_frames:
-            if lst and lst[0].tid == rx_frame.tid:
+            if lst and lst[0].tid == (rx_frame.tid & id_mask):
                 test_frame = lst.pop(0)
                 break
 
         assert test_frame is not None
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -310,23 +342,25 @@ def test_axis_arb_mux(request, ports, data_width, round_robin):
 
     parameters = {}
 
-    parameters['PORTS'] = ports
-
     parameters['DATA_WIDTH'] = data_width
     parameters['KEEP_ENABLE'] = int(parameters['DATA_WIDTH'] > 8)
     parameters['KEEP_WIDTH'] = parameters['DATA_WIDTH'] // 8
     parameters['ID_ENABLE'] = 1
-    parameters['ID_WIDTH'] = 8
+    parameters['S_ID_WIDTH'] = 8
+    parameters['M_ID_WIDTH'] = parameters['S_ID_WIDTH'] + (ports-1).bit_length()
     parameters['DEST_ENABLE'] = 1
     parameters['DEST_WIDTH'] = 8
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
     parameters['LAST_ENABLE'] = 1
+    parameters['UPDATE_TID'] = 1
     parameters['ARB_TYPE_ROUND_ROBIN'] = round_robin
     parameters['ARB_LSB_HIGH_PRIORITY'] = 1
 
     extra_env = {f'PARAM_{k}': str(v) for k, v in parameters.items()}
 
+    extra_env['PORTS'] = str(ports)
+
     sim_build = os.path.join(tests_dir, "sim_build",
         request.node.name.replace('[', '-').replace(']', ''))
 
diff --git a/lib/axis/tb/axis_async_fifo/test_axis_async_fifo.py b/lib/axis/tb/axis_async_fifo/test_axis_async_fifo.py
index a159eb38f58174dc09b0df020b7b19f03368aaeb..e064f98b43608ee9be038f984754aac14f7d7a63 100644
--- a/lib/axis/tb/axis_async_fifo/test_axis_async_fifo.py
+++ b/lib/axis/tb/axis_async_fifo/test_axis_async_fifo.py
@@ -68,12 +68,12 @@ class TB(object):
         self.dut.s_rst.setimmediatevalue(0)
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.m_rst <= 1
-        self.dut.s_rst <= 1
+        self.dut.m_rst.value = 1
+        self.dut.s_rst.value = 1
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.m_rst <= 0
-        self.dut.s_rst <= 0
+        self.dut.m_rst.value = 0
+        self.dut.s_rst.value = 0
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
 
@@ -81,10 +81,10 @@ class TB(object):
         self.dut.s_rst.setimmediatevalue(0)
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.s_rst <= 1
+        self.dut.s_rst.value = 1
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.s_rst <= 0
+        self.dut.s_rst.value = 0
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
 
@@ -92,10 +92,10 @@ class TB(object):
         self.dut.m_rst.setimmediatevalue(0)
         for k in range(10):
             await RisingEdge(self.dut.m_clk)
-        self.dut.m_rst <= 1
+        self.dut.m_rst.value = 1
         for k in range(10):
             await RisingEdge(self.dut.m_clk)
-        self.dut.m_rst <= 0
+        self.dut.m_rst.value = 0
         for k in range(10):
             await RisingEdge(self.dut.m_clk)
 
diff --git a/lib/axis/tb/axis_async_fifo_adapter/test_axis_async_fifo_adapter.py b/lib/axis/tb/axis_async_fifo_adapter/test_axis_async_fifo_adapter.py
index 521e86e844bf8cf4cd1601e15a491b3e41adab52..e97504f997c81e867c238ae5d132e33963e695a7 100644
--- a/lib/axis/tb/axis_async_fifo_adapter/test_axis_async_fifo_adapter.py
+++ b/lib/axis/tb/axis_async_fifo_adapter/test_axis_async_fifo_adapter.py
@@ -65,12 +65,12 @@ class TB(object):
         self.dut.s_rst.setimmediatevalue(0)
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.m_rst <= 1
-        self.dut.s_rst <= 1
+        self.dut.m_rst.value = 1
+        self.dut.s_rst.value = 1
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.m_rst <= 0
-        self.dut.s_rst <= 0
+        self.dut.m_rst.value = 0
+        self.dut.s_rst.value = 0
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
 
@@ -78,10 +78,10 @@ class TB(object):
         self.dut.s_rst.setimmediatevalue(0)
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.s_rst <= 1
+        self.dut.s_rst.value = 1
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
-        self.dut.s_rst <= 0
+        self.dut.s_rst.value = 0
         for k in range(10):
             await RisingEdge(self.dut.s_clk)
 
@@ -89,10 +89,10 @@ class TB(object):
         self.dut.m_rst.setimmediatevalue(0)
         for k in range(10):
             await RisingEdge(self.dut.m_clk)
-        self.dut.m_rst <= 1
+        self.dut.m_rst.value = 1
         for k in range(10):
             await RisingEdge(self.dut.m_clk)
-        self.dut.m_rst <= 0
+        self.dut.m_rst.value = 0
         for k in range(10):
             await RisingEdge(self.dut.m_clk)
 
diff --git a/lib/axis/tb/axis_broadcast/test_axis_broadcast.py b/lib/axis/tb/axis_broadcast/test_axis_broadcast.py
index 735c8dbbcb7ef0bc3624dc4827fce294512a6205..7099e51fd3f56db38025a25db4b0faf2c9973eba 100644
--- a/lib/axis/tb/axis_broadcast/test_axis_broadcast.py
+++ b/lib/axis/tb/axis_broadcast/test_axis_broadcast.py
@@ -66,10 +66,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_cobs_decode/test_axis_cobs_decode.py b/lib/axis/tb/axis_cobs_decode/test_axis_cobs_decode.py
index 3a5d6393b30f9a9d9fe66dbc2f4025f98c21900a..54fe33ef1aee76e5dd256bf3c66c2ccbb93483e9 100644
--- a/lib/axis/tb/axis_cobs_decode/test_axis_cobs_decode.py
+++ b/lib/axis/tb/axis_cobs_decode/test_axis_cobs_decode.py
@@ -129,10 +129,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_cobs_encode/test_axis_cobs_encode.py b/lib/axis/tb/axis_cobs_encode/test_axis_cobs_encode.py
index a5070cd91ea8568907572fd526c9afbe031e6194..ad9a7b6abfc22383ea006bcc8faf5bc489816f7b 100644
--- a/lib/axis/tb/axis_cobs_encode/test_axis_cobs_encode.py
+++ b/lib/axis/tb/axis_cobs_encode/test_axis_cobs_encode.py
@@ -130,10 +130,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_demux/Makefile b/lib/axis/tb/axis_demux/Makefile
index 8f529ef7781097596096d0da8ce9a3df5c39c52f..6fd45472c44729026641fdf1970bca311c490cc0 100644
--- a/lib/axis/tb/axis_demux/Makefile
+++ b/lib/axis/tb/axis_demux/Makefile
@@ -42,9 +42,11 @@ export PARAM_KEEP_WIDTH ?= $(shell expr $(PARAM_DATA_WIDTH) / 8 )
 export PARAM_ID_ENABLE ?= 1
 export PARAM_ID_WIDTH ?= 8
 export PARAM_DEST_ENABLE ?= 1
-export PARAM_DEST_WIDTH ?= 8
+export PARAM_M_DEST_WIDTH ?= 8
+export PARAM_S_DEST_WIDTH ?= $(shell python -c "print($(PARAM_M_DEST_WIDTH) + ($(PORTS)-1).bit_length())")
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
+export PARAM_TDEST_ROUTE ?= 1
 
 ifeq ($(SIM), icarus)
 	PLUSARGS += -fst
@@ -55,9 +57,11 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).ID_ENABLE=$(PARAM_ID_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).ID_WIDTH=$(PARAM_ID_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).DEST_ENABLE=$(PARAM_DEST_ENABLE)
-	COMPILE_ARGS += -P $(TOPLEVEL).DEST_WIDTH=$(PARAM_DEST_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).S_DEST_WIDTH=$(PARAM_S_DEST_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).M_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).TDEST_ROUTE=$(PARAM_TDEST_ROUTE)
 
 	ifeq ($(WAVES), 1)
 		VERILOG_SOURCES += iverilog_dump.v
@@ -72,9 +76,11 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GID_ENABLE=$(PARAM_ID_ENABLE)
 	COMPILE_ARGS += -GID_WIDTH=$(PARAM_ID_WIDTH)
 	COMPILE_ARGS += -GDEST_ENABLE=$(PARAM_DEST_ENABLE)
-	COMPILE_ARGS += -GDEST_WIDTH=$(PARAM_DEST_WIDTH)
+	COMPILE_ARGS += -GS_DEST_WIDTH=$(PARAM_S_DEST_WIDTH)
+	COMPILE_ARGS += -GM_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -GTDEST_ROUTE=$(PARAM_TDEST_ROUTE)
 
 	ifeq ($(WAVES), 1)
 		COMPILE_ARGS += --trace-fst
diff --git a/lib/axis/tb/axis_demux/test_axis_demux.py b/lib/axis/tb/axis_demux/test_axis_demux.py
index 5f980d0c881b6727ee949c3968799234750f07e7..355eac8086b9372ea6a5fa6002d87dfe8999ce4c 100644
--- a/lib/axis/tb/axis_demux/test_axis_demux.py
+++ b/lib/axis/tb/axis_demux/test_axis_demux.py
@@ -43,7 +43,7 @@ class TB(object):
     def __init__(self, dut):
         self.dut = dut
 
-        ports = int(os.getenv("PORTS"))
+        ports = len(dut.axis_demux_inst.m_axis_tvalid)
 
         self.log = logging.getLogger("cocotb.tb")
         self.log.setLevel(logging.DEBUG)
@@ -70,10 +70,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -82,7 +82,13 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source.bus.tid)
+    id_width = len(tb.source.bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    dest_width = len(tb.sink[0].bus.tdest)  # output tdest width (M_DEST_WIDTH); port index sits above these bits
+    dest_count = 2**dest_width
+    dest_mask = dest_count-1
 
     cur_id = 1
 
@@ -100,7 +106,7 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
         test_frame.tid = cur_id
-        test_frame.tdest = cur_id
+        test_frame.tdest = cur_id | (port << dest_width)
 
         test_frames.append(test_frame)
         await tb.source.send(test_frame)
@@ -112,7 +118,7 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
         assert rx_frame.tdata == test_frame.tdata
         assert rx_frame.tid == test_frame.tid
-        assert rx_frame.tdest == test_frame.tdest
+        assert rx_frame.tdest == (test_frame.tdest & dest_mask)
         assert not rx_frame.tuser
 
     assert tb.sink[port].empty()
@@ -137,7 +143,7 @@ def incrementing_payload(length):
 
 if cocotb.SIM_NAME:
 
-    ports = int(os.getenv("PORTS"))
+    ports = len(cocotb.top.axis_demux_inst.m_axis_tvalid)
 
     factory = TestFactory(run_test)
     factory.add_option("payload_lengths", [size_list])
@@ -154,9 +160,10 @@ tests_dir = os.path.dirname(__file__)
 rtl_dir = os.path.abspath(os.path.join(tests_dir, '..', '..', 'rtl'))
 
 
+@pytest.mark.parametrize("tdest_route", [0, 1])
 @pytest.mark.parametrize("data_width", [8, 16, 32])
 @pytest.mark.parametrize("ports", [4])
-def test_axis_demux(request, ports, data_width):
+def test_axis_demux(request, ports, data_width, tdest_route):
     dut = "axis_demux"
     wrapper = f"{dut}_wrap_{ports}"
     module = os.path.splitext(os.path.basename(__file__))[0]
@@ -183,9 +190,11 @@ def test_axis_demux(request, ports, data_width):
     parameters['ID_ENABLE'] = 1
     parameters['ID_WIDTH'] = 8
     parameters['DEST_ENABLE'] = 1
-    parameters['DEST_WIDTH'] = 8
+    parameters['M_DEST_WIDTH'] = 8
+    parameters['S_DEST_WIDTH'] = parameters['M_DEST_WIDTH'] + (ports-1).bit_length()
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
+    parameters['TDEST_ROUTE'] = tdest_route
 
     extra_env = {f'PARAM_{k}': str(v) for k, v in parameters.items()}
 
diff --git a/lib/axis/tb/axis_fifo/test_axis_fifo.py b/lib/axis/tb/axis_fifo/test_axis_fifo.py
index 88367b071452566e4dd26cf3ef420a004ee3ae43..d4bcf3ee4cd4be72796895db15f1e8819e50d0e1 100644
--- a/lib/axis/tb/axis_fifo/test_axis_fifo.py
+++ b/lib/axis/tb/axis_fifo/test_axis_fifo.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_fifo_adapter/test_axis_fifo_adapter.py b/lib/axis/tb/axis_fifo_adapter/test_axis_fifo_adapter.py
index f69b502dc044e7a962f8de7d0d6dbfd871e327fa..e5575b44e6e54ab4c29214703df7c55398226f97 100644
--- a/lib/axis/tb/axis_fifo_adapter/test_axis_fifo_adapter.py
+++ b/lib/axis/tb/axis_fifo_adapter/test_axis_fifo_adapter.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_frame_length_adjust/test_axis_frame_length_adjust.py b/lib/axis/tb/axis_frame_length_adjust/test_axis_frame_length_adjust.py
index 01785fa03aae3898eef6a5667e70bd9d663bf17b..43d7d2f597a38a0d0874960425335d76c789bf46 100644
--- a/lib/axis/tb/axis_frame_length_adjust/test_axis_frame_length_adjust.py
+++ b/lib/axis/tb/axis_frame_length_adjust/test_axis_frame_length_adjust.py
@@ -74,10 +74,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -102,8 +102,8 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
             tb.log.info("length_min %d, length_max %d", length_min, length_max)
 
             await RisingEdge(dut.clk)
-            tb.dut.length_min <= length_min
-            tb.dut.length_max <= length_max
+            tb.dut.length_min.value = length_min
+            tb.dut.length_max.value = length_max
             await RisingEdge(dut.clk)
 
             test_frames = []
diff --git a/lib/axis/tb/axis_frame_length_adjust_fifo/test_axis_frame_length_adjust_fifo.py b/lib/axis/tb/axis_frame_length_adjust_fifo/test_axis_frame_length_adjust_fifo.py
index 021ddf3216dfef3dff09b883edbf87f72984113a..5bc902db76dd37e882af80b55998f284f4848431 100644
--- a/lib/axis/tb/axis_frame_length_adjust_fifo/test_axis_frame_length_adjust_fifo.py
+++ b/lib/axis/tb/axis_frame_length_adjust_fifo/test_axis_frame_length_adjust_fifo.py
@@ -74,10 +74,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -102,8 +102,8 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
             tb.log.info("length_min %d, length_max %d", length_min, length_max)
 
             await RisingEdge(dut.clk)
-            tb.dut.length_min <= length_min
-            tb.dut.length_max <= length_max
+            tb.dut.length_min.value = length_min
+            tb.dut.length_max.value = length_max
             await RisingEdge(dut.clk)
 
             test_frames = []
diff --git a/lib/axis/tb/axis_mux/test_axis_mux.py b/lib/axis/tb/axis_mux/test_axis_mux.py
index 12a06efa0a5efb08951b982a17eae5bee33fc288..1f04e10b46b3b0da3d40a728dd16cdbc1f16d258 100644
--- a/lib/axis/tb/axis_mux/test_axis_mux.py
+++ b/lib/axis/tb/axis_mux/test_axis_mux.py
@@ -69,10 +69,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_pipeline_fifo/test_axis_pipeline_fifo.py b/lib/axis/tb/axis_pipeline_fifo/test_axis_pipeline_fifo.py
index 080b0755664d4dd5454d8d272bf66ad743ff1d9b..4ed2fb8858ac009103471105a41bf8c9a4d0b43b 100644
--- a/lib/axis/tb/axis_pipeline_fifo/test_axis_pipeline_fifo.py
+++ b/lib/axis/tb/axis_pipeline_fifo/test_axis_pipeline_fifo.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_pipeline_register/test_axis_pipeline_register.py b/lib/axis/tb/axis_pipeline_register/test_axis_pipeline_register.py
index 482ee561875c52571ef64391eed560fc1f490230..cbb3c01ab9c4778bad7f7af4f743c18b92be6835 100644
--- a/lib/axis/tb/axis_pipeline_register/test_axis_pipeline_register.py
+++ b/lib/axis/tb/axis_pipeline_register/test_axis_pipeline_register.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_ram_switch/Makefile b/lib/axis/tb/axis_ram_switch/Makefile
index c5e6aa307f218e1f7f55d3a286ec099531f6e47e..adfea859f019261fbcd2b520c947fcfa7cbb6e3c 100644
--- a/lib/axis/tb/axis_ram_switch/Makefile
+++ b/lib/axis/tb/axis_ram_switch/Makefile
@@ -26,11 +26,11 @@ WAVES ?= 0
 COCOTB_HDL_TIMEUNIT = 1ns
 COCOTB_HDL_TIMEPRECISION = 1ps
 
-export PARAM_S_COUNT ?= 4
-export PARAM_M_COUNT ?= 4
+export S_COUNT ?= 4
+export M_COUNT ?= 4
 
 DUT      = axis_ram_switch
-WRAPPER  = $(DUT)_wrap_$(PARAM_S_COUNT)x$(PARAM_M_COUNT)
+WRAPPER  = $(DUT)_wrap_$(S_COUNT)x$(M_COUNT)
 TOPLEVEL = $(WRAPPER)
 MODULE   = test_$(DUT)
 VERILOG_SOURCES += $(WRAPPER).v
@@ -50,14 +50,17 @@ export PARAM_M_DATA_WIDTH ?= 8
 export PARAM_M_KEEP_ENABLE ?= $(shell expr $(PARAM_M_DATA_WIDTH) \> 8 )
 export PARAM_M_KEEP_WIDTH ?= $(shell expr $(PARAM_M_DATA_WIDTH) / 8 )
 export PARAM_ID_ENABLE ?= 1
-export PARAM_ID_WIDTH ?= 16
-export PARAM_DEST_WIDTH ?= 8
+export PARAM_S_ID_WIDTH ?= 16
+export PARAM_M_ID_WIDTH ?= $(shell python -c "print($(PARAM_S_ID_WIDTH) + ($(S_COUNT)-1).bit_length())")
+export PARAM_M_DEST_WIDTH ?= 8
+export PARAM_S_DEST_WIDTH ?= $(shell python -c "print($(PARAM_M_DEST_WIDTH) + ($(M_COUNT)-1).bit_length())")
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
 export PARAM_USER_BAD_FRAME_VALUE ?= 1
 export PARAM_USER_BAD_FRAME_MASK ?= 1
 export PARAM_DROP_BAD_FRAME ?= 0
 export PARAM_DROP_WHEN_FULL ?= 0
+export PARAM_UPDATE_TID ?= 1
 export PARAM_ARB_TYPE_ROUND_ROBIN ?= 1
 export PARAM_ARB_LSB_HIGH_PRIORITY ?= 1
 export PARAM_RAM_PIPELINE ?= 2
@@ -75,14 +78,17 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).M_KEEP_ENABLE=$(PARAM_M_KEEP_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).M_KEEP_WIDTH=$(PARAM_M_KEEP_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).ID_ENABLE=$(PARAM_ID_ENABLE)
-	COMPILE_ARGS += -P $(TOPLEVEL).ID_WIDTH=$(PARAM_ID_WIDTH)
-	COMPILE_ARGS += -P $(TOPLEVEL).DEST_WIDTH=$(PARAM_DEST_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).S_ID_WIDTH=$(PARAM_S_ID_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).M_ID_WIDTH=$(PARAM_M_ID_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).S_DEST_WIDTH=$(PARAM_S_DEST_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).M_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_BAD_FRAME_VALUE=$(PARAM_USER_BAD_FRAME_VALUE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_BAD_FRAME_MASK=$(PARAM_USER_BAD_FRAME_MASK)
 	COMPILE_ARGS += -P $(TOPLEVEL).DROP_BAD_FRAME=$(PARAM_DROP_BAD_FRAME)
 	COMPILE_ARGS += -P $(TOPLEVEL).DROP_WHEN_FULL=$(PARAM_DROP_WHEN_FULL)
+	COMPILE_ARGS += -P $(TOPLEVEL).UPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 	COMPILE_ARGS += -P $(TOPLEVEL).RAM_PIPELINE=$(PARAM_RAM_PIPELINE)
@@ -104,14 +110,17 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GM_KEEP_ENABLE=$(PARAM_M_KEEP_ENABLE)
 	COMPILE_ARGS += -GM_KEEP_WIDTH=$(PARAM_M_KEEP_WIDTH)
 	COMPILE_ARGS += -GID_ENABLE=$(PARAM_ID_ENABLE)
-	COMPILE_ARGS += -GID_WIDTH=$(PARAM_ID_WIDTH)
-	COMPILE_ARGS += -GDEST_WIDTH=$(PARAM_DEST_WIDTH)
+	COMPILE_ARGS += -GS_ID_WIDTH=$(PARAM_S_ID_WIDTH)
+	COMPILE_ARGS += -GM_ID_WIDTH=$(PARAM_M_ID_WIDTH)
+	COMPILE_ARGS += -GS_DEST_WIDTH=$(PARAM_S_DEST_WIDTH)
+	COMPILE_ARGS += -GM_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
 	COMPILE_ARGS += -GUSER_BAD_FRAME_VALUE=$(PARAM_USER_BAD_FRAME_VALUE)
 	COMPILE_ARGS += -GUSER_BAD_FRAME_MASK=$(PARAM_USER_BAD_FRAME_MASK)
 	COMPILE_ARGS += -GDROP_BAD_FRAME=$(PARAM_DROP_BAD_FRAME)
 	COMPILE_ARGS += -GDROP_WHEN_FULL=$(PARAM_DROP_WHEN_FULL)
+	COMPILE_ARGS += -GUPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -GARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
 	COMPILE_ARGS += -GARB_LSB_HIGH_PRIORITY=$(PARAM_ARB_LSB_HIGH_PRIORITY)
 	COMPILE_ARGS += -GRAM_PIPELINE=$(PARAM_RAM_PIPELINE)
@@ -124,7 +133,7 @@ endif
 include $(shell cocotb-config --makefiles)/Makefile.sim
 
 $(WRAPPER).v: ../../rtl/$(DUT)_wrap.py
-	$< -p $(PARAM_S_COUNT) $(PARAM_M_COUNT)
+	$< -p $(S_COUNT) $(M_COUNT)
 
 iverilog_dump.v:
 	echo 'module iverilog_dump();' > $@
diff --git a/lib/axis/tb/axis_ram_switch/test_axis_ram_switch.py b/lib/axis/tb/axis_ram_switch/test_axis_ram_switch.py
index 62335c81b870ee6f131f94c8e01b42a66c06c0d0..e0071deb900699ea66cd6b6fcaa4cf10c78ab4d6 100644
--- a/lib/axis/tb/axis_ram_switch/test_axis_ram_switch.py
+++ b/lib/axis/tb/axis_ram_switch/test_axis_ram_switch.py
@@ -69,10 +69,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -81,7 +81,15 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source[s].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -94,19 +102,21 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
-        test_frame.tid = cur_id
+        test_frame.tid = cur_id | (s << src_shift)
         test_frame.tdest = m
 
         test_frames.append(test_frame)
         await tb.source[s].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for test_frame in test_frames:
         rx_frame = await tb.sink[m].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == s
+        assert (rx_frame.tid >> id_width) == s
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -142,7 +152,15 @@ async def run_arb_test(dut):
     tb = TB(dut)
 
     byte_lanes = max(tb.source[0].byte_lanes, tb.sink[0].byte_lanes)
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -155,8 +173,6 @@ async def run_arb_test(dut):
 
     for k in range(5):
         test_frame = AxiStreamFrame(test_data, tx_complete=Event())
-        test_frame.tid = cur_id
-        test_frame.tdest = 0
 
         src_ind = 0
 
@@ -170,17 +186,21 @@ async def run_arb_test(dut):
         else:
             src_ind = 1
 
+        test_frame.tid = cur_id | (src_ind << src_shift)
+        test_frame.tdest = 0
+
         test_frames.append(test_frame)
         await tb.source[src_ind].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for k in [0, 1, 2, 4, 3]:
         test_frame = test_frames[k]
         rx_frame = await tb.sink[0].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -195,7 +215,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
     tb = TB(dut)
 
     byte_lanes = max(tb.source[0].byte_lanes, tb.sink[0].byte_lanes)
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -211,13 +239,13 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
             length = random.randint(1, byte_lanes*16)
             test_data = bytearray(itertools.islice(itertools.cycle(range(256)), length))
             test_frame = AxiStreamFrame(test_data)
-            test_frame.tid = cur_id
+            test_frame.tid = cur_id | (p << src_shift)
             test_frame.tdest = random.randrange(len(tb.sink))
 
             test_frames[p][test_frame.tdest].append(test_frame)
             await tb.source[p].send(test_frame)
 
-            cur_id = (cur_id + 1) % id_count
+            cur_id = (cur_id + 1) % max_count
 
     for lst in test_frames:
         while any(lst):
@@ -227,14 +255,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
 
             for lst_a in test_frames:
                 for lst_b in lst_a:
-                    if lst_b and lst_b[0].tid == rx_frame.tid:
+                    if lst_b and lst_b[0].tid == (rx_frame.tid & id_mask):
                         test_frame = lst_b.pop(0)
                         break
 
             assert test_frame is not None
 
             assert rx_frame.tdata == test_frame.tdata
-            assert rx_frame.tid == test_frame.tid
+            assert (rx_frame.tid & id_mask) == test_frame.tid
+            assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
             assert rx_frame.tdest == test_frame.tdest
             assert not rx_frame.tuser
 
@@ -322,9 +351,6 @@ def test_axis_ram_switch(request, s_count, m_count, s_data_width, m_data_width):
 
     parameters = {}
 
-    parameters['S_COUNT'] = s_count
-    parameters['M_COUNT'] = m_count
-
     parameters['FIFO_DEPTH'] = 4096
     parameters['CMD_FIFO_DEPTH'] = 32
     parameters['SPEEDUP'] = 0
@@ -335,20 +361,26 @@ def test_axis_ram_switch(request, s_count, m_count, s_data_width, m_data_width):
     parameters['M_KEEP_ENABLE'] = int(parameters['M_DATA_WIDTH'] > 8)
     parameters['M_KEEP_WIDTH'] = parameters['M_DATA_WIDTH'] // 8
     parameters['ID_ENABLE'] = 1
-    parameters['ID_WIDTH'] = 16
-    parameters['DEST_WIDTH'] = 8
+    parameters['S_ID_WIDTH'] = 16
+    parameters['M_ID_WIDTH'] = parameters['S_ID_WIDTH'] + (s_count-1).bit_length()
+    parameters['M_DEST_WIDTH'] = 8
+    parameters['S_DEST_WIDTH'] = parameters['M_DEST_WIDTH'] + (m_count-1).bit_length()
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
     parameters['USER_BAD_FRAME_VALUE'] = 1
     parameters['USER_BAD_FRAME_MASK'] = 1
     parameters['DROP_BAD_FRAME'] = 0
     parameters['DROP_WHEN_FULL'] = 0
+    parameters['UPDATE_TID'] = 1
     parameters['ARB_TYPE_ROUND_ROBIN'] = 1
     parameters['ARB_LSB_HIGH_PRIORITY'] = 1
     parameters['RAM_PIPELINE'] = 2
 
     extra_env = {f'PARAM_{k}': str(v) for k, v in parameters.items()}
 
+    extra_env['S_COUNT'] = str(s_count)
+    extra_env['M_COUNT'] = str(m_count)
+
     sim_build = os.path.join(tests_dir, "sim_build",
         request.node.name.replace('[', '-').replace(']', ''))
 
diff --git a/lib/axis/tb/axis_rate_limit/test_axis_rate_limit.py b/lib/axis/tb/axis_rate_limit/test_axis_rate_limit.py
index 91a16937d85433d53779f2c0299a2370e0a4398f..00dd0a3c780ab27c8c3f2cd910a9b7be059108cf 100644
--- a/lib/axis/tb/axis_rate_limit/test_axis_rate_limit.py
+++ b/lib/axis/tb/axis_rate_limit/test_axis_rate_limit.py
@@ -66,10 +66,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -87,8 +87,8 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
     tb.set_idle_generator(idle_inserter)
     tb.set_backpressure_generator(backpressure_inserter)
 
-    dut.rate_num <= rate[0]
-    dut.rate_denom <= rate[1]
+    dut.rate_num.value = rate[0]
+    dut.rate_denom.value = rate[1]
     await RisingEdge(dut.clk)
 
     test_frames = []
diff --git a/lib/axis/tb/axis_register/test_axis_register.py b/lib/axis/tb/axis_register/test_axis_register.py
index 198622bd780e8a26ddb4a11e82e83d6ab1c71e7c..87a5e0234917f4dc33af4f72e70f9c18893f8983 100644
--- a/lib/axis/tb/axis_register/test_axis_register.py
+++ b/lib/axis/tb/axis_register/test_axis_register.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_srl_fifo/test_axis_srl_fifo.py b/lib/axis/tb/axis_srl_fifo/test_axis_srl_fifo.py
index dc03f257255ccb1405898353ab48f453ebc5753b..634ad73bf916ea1feda1c3103ab7e6504a6e3196 100644
--- a/lib/axis/tb/axis_srl_fifo/test_axis_srl_fifo.py
+++ b/lib/axis/tb/axis_srl_fifo/test_axis_srl_fifo.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_srl_register/test_axis_srl_register.py b/lib/axis/tb/axis_srl_register/test_axis_srl_register.py
index de826611dcadc3821d011b0102ccc06d12347551..d9128c8df5ee40e41e918232aa2991a209f875fd 100644
--- a/lib/axis/tb/axis_srl_register/test_axis_srl_register.py
+++ b/lib/axis/tb/axis_srl_register/test_axis_srl_register.py
@@ -63,10 +63,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
diff --git a/lib/axis/tb/axis_switch/Makefile b/lib/axis/tb/axis_switch/Makefile
index d066c53933d30672b42dcb6566b06bfc2ec1c78d..9cf5559934b5339837f0cc42b2ed4e27c42cf8b3 100644
--- a/lib/axis/tb/axis_switch/Makefile
+++ b/lib/axis/tb/axis_switch/Makefile
@@ -26,11 +26,11 @@ WAVES ?= 0
 COCOTB_HDL_TIMEUNIT = 1ns
 COCOTB_HDL_TIMEPRECISION = 1ps
 
-export PARAM_S_COUNT ?= 4
-export PARAM_M_COUNT ?= 4
+export S_COUNT ?= 4
+export M_COUNT ?= 4
 
 DUT      = axis_switch
-WRAPPER  = $(DUT)_wrap_$(PARAM_S_COUNT)x$(PARAM_M_COUNT)
+WRAPPER  = $(DUT)_wrap_$(S_COUNT)x$(M_COUNT)
 TOPLEVEL = $(WRAPPER)
 MODULE   = test_$(DUT)
 VERILOG_SOURCES += $(WRAPPER).v
@@ -44,10 +44,13 @@ export PARAM_DATA_WIDTH ?= 8
 export PARAM_KEEP_ENABLE ?= $(shell expr $(PARAM_DATA_WIDTH) \> 8 )
 export PARAM_KEEP_WIDTH ?= $(shell expr $(PARAM_DATA_WIDTH) / 8 )
 export PARAM_ID_ENABLE ?= 1
-export PARAM_ID_WIDTH ?= 16
-export PARAM_DEST_WIDTH ?= 8
+export PARAM_S_ID_WIDTH ?= 16
+export PARAM_M_ID_WIDTH ?= $(shell python -c "print($(PARAM_S_ID_WIDTH) + ($(S_COUNT)-1).bit_length())")
+export PARAM_M_DEST_WIDTH ?= 8
+export PARAM_S_DEST_WIDTH ?= $(shell python -c "print($(PARAM_M_DEST_WIDTH) + ($(M_COUNT)-1).bit_length())")
 export PARAM_USER_ENABLE ?= 1
 export PARAM_USER_WIDTH ?= 1
+export PARAM_UPDATE_TID ?= 1
 export PARAM_S_REG_TYPE ?= 0
 export PARAM_M_REG_TYPE ?= 2
 export PARAM_ARB_TYPE_ROUND_ROBIN ?= 1
@@ -60,10 +63,13 @@ ifeq ($(SIM), icarus)
 	COMPILE_ARGS += -P $(TOPLEVEL).KEEP_ENABLE=$(PARAM_KEEP_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).KEEP_WIDTH=$(PARAM_KEEP_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).ID_ENABLE=$(PARAM_ID_ENABLE)
-	COMPILE_ARGS += -P $(TOPLEVEL).ID_WIDTH=$(PARAM_ID_WIDTH)
-	COMPILE_ARGS += -P $(TOPLEVEL).DEST_WIDTH=$(PARAM_DEST_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).S_ID_WIDTH=$(PARAM_S_ID_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).M_ID_WIDTH=$(PARAM_M_ID_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).S_DEST_WIDTH=$(PARAM_S_DEST_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).M_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -P $(TOPLEVEL).USER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -P $(TOPLEVEL).UPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -P $(TOPLEVEL).S_REG_TYPE=$(PARAM_S_REG_TYPE)
 	COMPILE_ARGS += -P $(TOPLEVEL).M_REG_TYPE=$(PARAM_M_REG_TYPE)
 	COMPILE_ARGS += -P $(TOPLEVEL).ARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
@@ -80,10 +86,13 @@ else ifeq ($(SIM), verilator)
 	COMPILE_ARGS += -GKEEP_ENABLE=$(PARAM_KEEP_ENABLE)
 	COMPILE_ARGS += -GKEEP_WIDTH=$(PARAM_KEEP_WIDTH)
 	COMPILE_ARGS += -GID_ENABLE=$(PARAM_ID_ENABLE)
-	COMPILE_ARGS += -GID_WIDTH=$(PARAM_ID_WIDTH)
-	COMPILE_ARGS += -GDEST_WIDTH=$(PARAM_DEST_WIDTH)
+	COMPILE_ARGS += -GS_ID_WIDTH=$(PARAM_S_ID_WIDTH)
+	COMPILE_ARGS += -GM_ID_WIDTH=$(PARAM_M_ID_WIDTH)
+	COMPILE_ARGS += -GS_DEST_WIDTH=$(PARAM_S_DEST_WIDTH)
+	COMPILE_ARGS += -GM_DEST_WIDTH=$(PARAM_M_DEST_WIDTH)
 	COMPILE_ARGS += -GUSER_ENABLE=$(PARAM_USER_ENABLE)
 	COMPILE_ARGS += -GUSER_WIDTH=$(PARAM_USER_WIDTH)
+	COMPILE_ARGS += -GUPDATE_TID=$(PARAM_UPDATE_TID)
 	COMPILE_ARGS += -GS_REG_TYPE=$(PARAM_S_REG_TYPE)
 	COMPILE_ARGS += -GM_REG_TYPE=$(PARAM_M_REG_TYPE)
 	COMPILE_ARGS += -GARB_TYPE_ROUND_ROBIN=$(PARAM_ARB_TYPE_ROUND_ROBIN)
@@ -97,7 +106,7 @@ endif
 include $(shell cocotb-config --makefiles)/Makefile.sim
 
 $(WRAPPER).v: ../../rtl/$(DUT)_wrap.py
-	$< -p $(PARAM_S_COUNT) $(PARAM_M_COUNT)
+	$< -p $(S_COUNT) $(M_COUNT)
 
 iverilog_dump.v:
 	echo 'module iverilog_dump();' > $@
diff --git a/lib/axis/tb/axis_switch/test_axis_switch.py b/lib/axis/tb/axis_switch/test_axis_switch.py
index 7132a00db6c2daf5401d45cfbca3c819a102aaa4..6bc84d1a79dc469b6090e99f2be25a70fc220a7f 100644
--- a/lib/axis/tb/axis_switch/test_axis_switch.py
+++ b/lib/axis/tb/axis_switch/test_axis_switch.py
@@ -69,10 +69,10 @@ class TB(object):
         self.dut.rst.setimmediatevalue(0)
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 1
+        self.dut.rst.value = 1
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
-        self.dut.rst <= 0
+        self.dut.rst.value = 0
         await RisingEdge(self.dut.clk)
         await RisingEdge(self.dut.clk)
 
@@ -81,7 +81,15 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     tb = TB(dut)
 
-    id_count = 2**len(tb.source[s].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -94,19 +102,21 @@ async def run_test(dut, payload_lengths=None, payload_data=None, idle_inserter=N
 
     for test_data in [payload_data(x) for x in payload_lengths()]:
         test_frame = AxiStreamFrame(test_data)
-        test_frame.tid = cur_id
+        test_frame.tid = cur_id | (s << src_shift)
         test_frame.tdest = m
 
         test_frames.append(test_frame)
         await tb.source[s].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for test_frame in test_frames:
         rx_frame = await tb.sink[m].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == s
+        assert (rx_frame.tid >> id_width) == s
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -142,7 +152,15 @@ async def run_arb_test(dut):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -155,8 +173,6 @@ async def run_arb_test(dut):
 
     for k in range(5):
         test_frame = AxiStreamFrame(test_data, tx_complete=Event())
-        test_frame.tid = cur_id
-        test_frame.tdest = 0
 
         src_ind = 0
 
@@ -170,17 +186,21 @@ async def run_arb_test(dut):
         else:
             src_ind = 1
 
+        test_frame.tid = cur_id | (src_ind << src_shift)
+        test_frame.tdest = 0
+
         test_frames.append(test_frame)
         await tb.source[src_ind].send(test_frame)
 
-        cur_id = (cur_id + 1) % id_count
+        cur_id = (cur_id + 1) % max_count
 
     for k in [0, 1, 2, 4, 3]:
         test_frame = test_frames[k]
         rx_frame = await tb.sink[0].recv()
 
         assert rx_frame.tdata == test_frame.tdata
-        assert rx_frame.tid == test_frame.tid
+        assert (rx_frame.tid & id_mask) == test_frame.tid
+        assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
         assert rx_frame.tdest == test_frame.tdest
         assert not rx_frame.tuser
 
@@ -195,7 +215,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
     tb = TB(dut)
 
     byte_lanes = tb.source[0].byte_lanes
-    id_count = 2**len(tb.source[0].bus.tid)
+    id_width = len(tb.source[0].bus.tid)
+    id_count = 2**id_width
+    id_mask = id_count-1
+
+    src_width = (len(tb.source)-1).bit_length()
+    src_mask = 2**src_width-1 if src_width else 0
+    src_shift = id_width-src_width
+    max_count = 2**src_shift
+    count_mask = max_count-1
 
     cur_id = 1
 
@@ -211,13 +239,13 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
             length = random.randint(1, byte_lanes*16)
             test_data = bytearray(itertools.islice(itertools.cycle(range(256)), length))
             test_frame = AxiStreamFrame(test_data)
-            test_frame.tid = cur_id
+            test_frame.tid = cur_id | (p << src_shift)
             test_frame.tdest = random.randrange(len(tb.sink))
 
             test_frames[p][test_frame.tdest].append(test_frame)
             await tb.source[p].send(test_frame)
 
-            cur_id = (cur_id + 1) % id_count
+            cur_id = (cur_id + 1) % max_count
 
     for lst in test_frames:
         while any(lst):
@@ -227,15 +255,15 @@ async def run_stress_test(dut, idle_inserter=None, backpressure_inserter=None):
 
             for lst_a in test_frames:
                 for lst_b in lst_a:
-                    if lst_b and lst_b[0].tid == rx_frame.tid:
+                    if lst_b and lst_b[0].tid == (rx_frame.tid & id_mask):
                         test_frame = lst_b.pop(0)
                         break
 
             assert test_frame is not None
 
             assert rx_frame.tdata == test_frame.tdata
-            assert rx_frame.tid == test_frame.tid
-            assert rx_frame.tdest == test_frame.tdest
+            assert (rx_frame.tid & id_mask) == test_frame.tid
+            assert ((rx_frame.tid >> src_shift) & src_mask) == (rx_frame.tid >> id_width)
             assert not rx_frame.tuser
 
     assert all(s.empty() for s in tb.sink)
@@ -321,17 +349,17 @@ def test_axis_switch(request, s_count, m_count, data_width):
 
     parameters = {}
 
-    parameters['S_COUNT'] = s_count
-    parameters['M_COUNT'] = m_count
-
     parameters['DATA_WIDTH'] = data_width
     parameters['KEEP_ENABLE'] = int(parameters['DATA_WIDTH'] > 8)
     parameters['KEEP_WIDTH'] = parameters['DATA_WIDTH'] // 8
     parameters['ID_ENABLE'] = 1
-    parameters['ID_WIDTH'] = 16
-    parameters['DEST_WIDTH'] = 8
+    parameters['S_ID_WIDTH'] = 16
+    parameters['M_ID_WIDTH'] = parameters['S_ID_WIDTH'] + (s_count-1).bit_length()
+    parameters['M_DEST_WIDTH'] = 8
+    parameters['S_DEST_WIDTH'] = parameters['M_DEST_WIDTH'] + (m_count-1).bit_length()
     parameters['USER_ENABLE'] = 1
     parameters['USER_WIDTH'] = 1
+    parameters['UPDATE_TID'] = 1
     parameters['S_REG_TYPE'] = 0
     parameters['M_REG_TYPE'] = 2
     parameters['ARB_TYPE_ROUND_ROBIN'] = 1
@@ -339,6 +367,9 @@ def test_axis_switch(request, s_count, m_count, data_width):
 
     extra_env = {f'PARAM_{k}': str(v) for k, v in parameters.items()}
 
+    extra_env['S_COUNT'] = str(s_count)
+    extra_env['M_COUNT'] = str(m_count)
+
     sim_build = os.path.join(tests_dir, "sim_build",
         request.node.name.replace('[', '-').replace(']', ''))