From 4a468910c97da78260b94691eef22d4ebe961610 Mon Sep 17 00:00:00 2001
From: Ryan Izard <rizard@g.clemson.edu>
Date: Fri, 7 Nov 2014 10:05:17 -0500
Subject: [PATCH] Firewall working and unit tests working. Found an fixed a
 couple bugs in the Firewall during the process.

---
 .../firewall/Firewall.java                    |  17 +-
 .../firewall/FirewallRule.java                | 206 +++++++++++-------
 .../firewall/FirewallTest.java                |  12 +-
 3 files changed, 145 insertions(+), 90 deletions(-)

diff --git a/src/main/java/net/floodlightcontroller/firewall/Firewall.java b/src/main/java/net/floodlightcontroller/firewall/Firewall.java
index c50097758..d77bed580 100644
--- a/src/main/java/net/floodlightcontroller/firewall/Firewall.java
+++ b/src/main/java/net/floodlightcontroller/firewall/Firewall.java
@@ -31,6 +31,7 @@ import org.projectfloodlight.openflow.protocol.OFVersion;
 import org.projectfloodlight.openflow.protocol.match.MatchField;
 import org.projectfloodlight.openflow.types.DatapathId;
 import org.projectfloodlight.openflow.types.EthType;
+import org.projectfloodlight.openflow.types.IPv4Address;
 import org.projectfloodlight.openflow.types.IPv4AddressWithMask;
 import org.projectfloodlight.openflow.types.IpProtocol;
 import org.projectfloodlight.openflow.types.MacAddress;
@@ -80,7 +81,7 @@ public class Firewall implements IFirewallService, IOFMessageListener,
 
     protected List<FirewallRule> rules; // protected by synchronized
     protected boolean enabled;
-    protected int subnet_mask = IPv4.toIPv4Address("255.255.255.0");
+    protected IPv4Address subnet_mask = IPv4Address.of("255.255.255.0");
 
     // constant strings for storage/parsing
     public static final String TABLE_NAME = "controller_firewallrules";
@@ -349,14 +350,14 @@ public class Firewall implements IFirewallService, IOFMessageListener,
 
     @Override
     public String getSubnetMask() {
-        return IPv4.fromIPv4Address(this.subnet_mask);
+        return this.subnet_mask.toString();
     }
 
     @Override
     public void setSubnetMask(String newMask) {
         if (newMask.trim().isEmpty())
             return;
-        this.subnet_mask = IPv4.toIPv4Address(newMask.trim());
+        this.subnet_mask = IPv4Address.of(newMask.trim());
     }
 
     @Override
@@ -497,10 +498,10 @@ public class Firewall implements IFirewallService, IOFMessageListener,
      *            the IP address to check
      * @return true if it is a broadcast address, false otherwise
      */
-    protected boolean IPIsBroadcast(int IPAddress) {
+    protected boolean isIPBroadcast(IPv4Address ip) {
         // inverted subnet mask
-        int inv_subnet_mask = ~this.subnet_mask;
-        return ((IPAddress & inv_subnet_mask) == inv_subnet_mask);
+        IPv4Address inv_subnet_mask = subnet_mask.not();
+        return ip.and(inv_subnet_mask).equals(inv_subnet_mask);
     }
 
     public Command processPacketInMessage(IOFSwitch sw, OFPacketIn pi, IRoutingDecision decision, FloodlightContext cntx) {
@@ -511,9 +512,9 @@ public class Firewall implements IFirewallService, IOFMessageListener,
         // broadcasts -> L2 broadcast + L3 unicast)
         if (eth.isBroadcast() == true) {
             boolean allowBroadcast = true;
-            // the case to determine if we have L2 broadcast + L3 unicast
+            // the case to determine if we have L2 broadcast + L3 unicast (L3 broadcast default set to /24 or 255.255.255.0)
             // don't allow this broadcast packet if such is the case (malformed packet)
-            if ((eth.getPayload() instanceof IPv4) && (((IPv4) eth.getPayload()).getDestinationAddress().isBroadcast() == false)) {
+            if ((eth.getPayload() instanceof IPv4) && !isIPBroadcast(((IPv4) eth.getPayload()).getDestinationAddress())) {
                 allowBroadcast = false;
             }
             if (allowBroadcast == true) {
diff --git a/src/main/java/net/floodlightcontroller/firewall/FirewallRule.java b/src/main/java/net/floodlightcontroller/firewall/FirewallRule.java
index fdb3c1b27..9f6fc1e21 100644
--- a/src/main/java/net/floodlightcontroller/firewall/FirewallRule.java
+++ b/src/main/java/net/floodlightcontroller/firewall/FirewallRule.java
@@ -22,7 +22,6 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 import org.projectfloodlight.openflow.protocol.match.MatchField;
 import org.projectfloodlight.openflow.types.DatapathId;
 import org.projectfloodlight.openflow.types.EthType;
-import org.projectfloodlight.openflow.types.IPv4Address;
 import org.projectfloodlight.openflow.types.IPv4AddressWithMask;
 import org.projectfloodlight.openflow.types.IpProtocol;
 import org.projectfloodlight.openflow.types.MacAddress;
@@ -37,7 +36,95 @@ import net.floodlightcontroller.packet.UDP;
 
 @JsonSerialize(using=FirewallRuleSerializer.class)
 public class FirewallRule implements Comparable<FirewallRule> {
-    public int ruleid;
+    @Override
+	public boolean equals(Object obj) {
+		if (this == obj)
+			return true;
+		if (obj == null)
+			return false;
+		if (getClass() != obj.getClass())
+			return false;
+		FirewallRule other = (FirewallRule) obj;
+		if (action != other.action)
+			return false;
+		if (any_dl_dst != other.any_dl_dst)
+			return false;
+		if (any_dl_src != other.any_dl_src)
+			return false;
+		if (any_dl_type != other.any_dl_type)
+			return false;
+		if (any_dpid != other.any_dpid)
+			return false;
+		if (any_in_port != other.any_in_port)
+			return false;
+		if (any_nw_dst != other.any_nw_dst)
+			return false;
+		if (any_nw_proto != other.any_nw_proto)
+			return false;
+		if (any_nw_src != other.any_nw_src)
+			return false;
+		if (any_tp_dst != other.any_tp_dst)
+			return false;
+		if (any_tp_src != other.any_tp_src)
+			return false;
+		if (dl_dst == null) {
+			if (other.dl_dst != null)
+				return false;
+		} else if (!dl_dst.equals(other.dl_dst))
+			return false;
+		if (dl_src == null) {
+			if (other.dl_src != null)
+				return false;
+		} else if (!dl_src.equals(other.dl_src))
+			return false;
+		if (dl_type == null) {
+			if (other.dl_type != null)
+				return false;
+		} else if (!dl_type.equals(other.dl_type))
+			return false;
+		if (dpid == null) {
+			if (other.dpid != null)
+				return false;
+		} else if (!dpid.equals(other.dpid))
+			return false;
+		if (in_port == null) {
+			if (other.in_port != null)
+				return false;
+		} else if (!in_port.equals(other.in_port))
+			return false;
+		if (nw_dst_prefix_and_mask == null) {
+			if (other.nw_dst_prefix_and_mask != null)
+				return false;
+		} else if (!nw_dst_prefix_and_mask.equals(other.nw_dst_prefix_and_mask))
+			return false;
+		if (nw_proto == null) {
+			if (other.nw_proto != null)
+				return false;
+		} else if (!nw_proto.equals(other.nw_proto))
+			return false;
+		if (nw_src_prefix_and_mask == null) {
+			if (other.nw_src_prefix_and_mask != null)
+				return false;
+		} else if (!nw_src_prefix_and_mask.equals(other.nw_src_prefix_and_mask))
+			return false;
+		if (priority != other.priority)
+			return false;
+		if (ruleid != other.ruleid)
+			return false;
+		if (tp_dst == null) {
+			if (other.tp_dst != null)
+				return false;
+		} else if (!tp_dst.equals(other.tp_dst))
+			return false;
+		if (tp_src == null) {
+			if (other.tp_src != null)
+				return false;
+		} else if (!tp_src.equals(other.tp_src))
+			return false;
+		return true;
+	}
+
+	public int ruleid;
 
     public DatapathId dpid; 
     public OFPort in_port; 
@@ -81,18 +168,16 @@ public class FirewallRule implements Comparable<FirewallRule> {
      * The default rule is to match on anything.
      */
     public FirewallRule() {
+        this.dpid = DatapathId.NONE;
         this.in_port = OFPort.ANY; 
         this.dl_src = MacAddress.NONE;
-        this.nw_src_prefix_and_mask = IPv4AddressWithMask.NONE;
-        //this.nw_src_maskbits = 0; 
         this.dl_dst = MacAddress.NONE;
+        this.dl_type = EthType.NONE;
+        this.nw_src_prefix_and_mask = IPv4AddressWithMask.NONE;
+        this.nw_dst_prefix_and_mask = IPv4AddressWithMask.NONE;
         this.nw_proto = IpProtocol.NONE;
         this.tp_src = TransportPort.NONE;
         this.tp_dst = TransportPort.NONE;
-        this.dl_dst = MacAddress.NONE;
-        this.nw_dst_prefix_and_mask = IPv4AddressWithMask.NONE;
-        //this.nw_dst_maskbits = 0; 
-        this.dpid = DatapathId.NONE;
         this.any_dpid = true; 
         this.any_in_port = true; 
         this.any_dl_src = true; 
@@ -267,7 +352,7 @@ public class FirewallRule implements Comparable<FirewallRule> {
                     pkt_ip = (IPv4) pkt;
 
                     // IP addresses (src and dst) match?
-                    if (any_nw_src == false && this.matchIPAddress(nw_src_prefix_and_mask.getValue().getInt(), nw_src_prefix_and_mask.getMask().getInt(), pkt_ip.getSourceAddress()) == false)
+                    if (any_nw_src == false && !nw_src_prefix_and_mask.matches(pkt_ip.getSourceAddress()))
                         return false;
                     if (action == FirewallRule.FirewallAction.DROP) {
                         //wildcards.drop &= ~OFMatch.OFPFW_NW_SRC_ALL;
@@ -279,7 +364,7 @@ public class FirewallRule implements Comparable<FirewallRule> {
                     	adp.allow.setMasked(MatchField.IPV4_SRC, nw_src_prefix_and_mask);
                     }
 
-                    if (any_nw_dst == false && this.matchIPAddress(nw_dst_prefix_and_mask.getValue().getInt(), nw_dst_prefix_and_mask.getMask().getInt(), pkt_ip.getDestinationAddress()) == false)
+                    if (any_nw_dst == false && !nw_dst_prefix_and_mask.matches(pkt_ip.getDestinationAddress()))
                         return false;
                     if (action == FirewallRule.FirewallAction.DROP) {
                         //wildcards.drop &= ~OFMatch.OFPFW_NW_DST_ALL;
@@ -386,71 +471,40 @@ public class FirewallRule implements Comparable<FirewallRule> {
         return true;
     }
 
-    /**
-     * Determines if rule's CIDR address matches IP address of the packet
-     * 
-     * @param rulePrefix
-     *            prefix part of the CIDR address
-     * @param ruleBits
-     *            the size of mask of the CIDR address
-     * @param packetAddress
-     *            the IP address of the incoming packet to match with
-     * @return true if CIDR address matches the packet's IP address, false
-     *         otherwise
-     */
-    protected boolean matchIPAddress(int rulePrefix, int ruleBits, IPv4Address packetAddress) {
-        boolean matched = true;
-
-        int rule_iprng = 32 - ruleBits;
-        int rule_ipint = rulePrefix;
-        int pkt_ipint = packetAddress.getInt();
-        // if there's a subnet range (bits to be wildcarded > 0)
-        if (rule_iprng > 0) {
-            // right shift bits to remove rule_iprng of LSB that are to be
-            // wildcarded
-            rule_ipint = rule_ipint >> rule_iprng;
-            pkt_ipint = pkt_ipint >> rule_iprng;
-            // now left shift to return to normal range, except that the
-            // rule_iprng number of LSB
-            // are now zeroed
-            rule_ipint = rule_ipint << rule_iprng;
-            pkt_ipint = pkt_ipint << rule_iprng;
-        }
-        // check if we have a match
-        if (rule_ipint != pkt_ipint)
-            matched = false;
-
-        return matched;
-    }
-
     @Override
-    public int hashCode() {
-        final int prime = 2521;
-        int result = super.hashCode();
-        result = prime * result + (int) dpid.getLong();
-        result = prime * result + in_port.getPortNumber();
-        result = prime * result + (int) dl_src.getLong();
-        result = prime * result + (int) dl_dst.getLong();
-        result = prime * result + dl_type.getValue();
-        result = prime * result + nw_src_prefix_and_mask.getValue().getInt();
-        result = prime * result + nw_src_prefix_and_mask.getMask().getInt();
-        result = prime * result + nw_dst_prefix_and_mask.getValue().getInt();
-        result = prime * result + nw_dst_prefix_and_mask.getMask().getInt();
-        result = prime * result + nw_proto.getIpProtocolNumber();
-        result = prime * result + tp_src.getPort();
-        result = prime * result + tp_dst.getPort();
-        result = prime * result + action.ordinal();
-        result = prime * result + priority;
-        result = prime * result + (new Boolean(any_dpid)).hashCode();
-        result = prime * result + (new Boolean(any_in_port)).hashCode();
-        result = prime * result + (new Boolean(any_dl_src)).hashCode();
-        result = prime * result + (new Boolean(any_dl_dst)).hashCode();
-        result = prime * result + (new Boolean(any_dl_type)).hashCode();
-        result = prime * result + (new Boolean(any_nw_src)).hashCode();
-        result = prime * result + (new Boolean(any_nw_dst)).hashCode();
-        result = prime * result + (new Boolean(any_nw_proto)).hashCode();
-        result = prime * result + (new Boolean(any_tp_src)).hashCode();
-        result = prime * result + (new Boolean(any_tp_dst)).hashCode();
-        return result;
-    }
+	public int hashCode() {
+		final int prime = 31;
+		int result = 1;
+		result = prime * result + ((action == null) ? 0 : action.hashCode());
+		result = prime * result + (any_dl_dst ? 1231 : 1237);
+		result = prime * result + (any_dl_src ? 1231 : 1237);
+		result = prime * result + (any_dl_type ? 1231 : 1237);
+		result = prime * result + (any_dpid ? 1231 : 1237);
+		result = prime * result + (any_in_port ? 1231 : 1237);
+		result = prime * result + (any_nw_dst ? 1231 : 1237);
+		result = prime * result + (any_nw_proto ? 1231 : 1237);
+		result = prime * result + (any_nw_src ? 1231 : 1237);
+		result = prime * result + (any_tp_dst ? 1231 : 1237);
+		result = prime * result + (any_tp_src ? 1231 : 1237);
+		result = prime * result + ((dl_dst == null) ? 0 : dl_dst.hashCode());
+		result = prime * result + ((dl_src == null) ? 0 : dl_src.hashCode());
+		result = prime * result + ((dl_type == null) ? 0 : dl_type.hashCode());
+		result = prime * result + ((dpid == null) ? 0 : dpid.hashCode());
+		result = prime * result + ((in_port == null) ? 0 : in_port.hashCode());
+		result = prime
+				* result
+				+ ((nw_dst_prefix_and_mask == null) ? 0
+						: nw_dst_prefix_and_mask.hashCode());
+		result = prime * result
+				+ ((nw_proto == null) ? 0 : nw_proto.hashCode());
+		result = prime
+				* result
+				+ ((nw_src_prefix_and_mask == null) ? 0
+						: nw_src_prefix_and_mask.hashCode());
+		result = prime * result + priority;
+		result = prime * result + ruleid;
+		result = prime * result + ((tp_dst == null) ? 0 : tp_dst.hashCode());
+		result = prime * result + ((tp_src == null) ? 0 : tp_src.hashCode());
+		return result;
+	}
 }
diff --git a/src/test/java/net/floodlightcontroller/firewall/FirewallTest.java b/src/test/java/net/floodlightcontroller/firewall/FirewallTest.java
index 245b369b0..b97550afa 100644
--- a/src/test/java/net/floodlightcontroller/firewall/FirewallTest.java
+++ b/src/test/java/net/floodlightcontroller/firewall/FirewallTest.java
@@ -273,19 +273,19 @@ public class FirewallTest extends FloodlightTestCase {
         List<FirewallRule> rules = firewall.readRulesFromStorage();
         // verify rule 1
         FirewallRule r = rules.get(0);
-        assertEquals(r.in_port, 2);
+        assertEquals(r.in_port, OFPort.of(2));
         assertEquals(r.priority, 1);
         assertEquals(r.dl_src, MacAddress.of("00:00:00:00:00:01"));
         assertEquals(r.dl_dst, MacAddress.of("00:00:00:00:00:02"));
         assertEquals(r.action, FirewallRule.FirewallAction.DROP);
         // verify rule 2
         r = rules.get(1);
-        assertEquals(r.in_port, 3);
+        assertEquals(r.in_port, OFPort.of(3));
         assertEquals(r.priority, 2);
         assertEquals(r.dl_src, MacAddress.of("00:00:00:00:00:02"));
         assertEquals(r.dl_dst, MacAddress.of("00:00:00:00:00:01"));
         assertEquals(r.nw_proto, IpProtocol.TCP);
-        assertEquals(r.tp_dst, 80);
+        assertEquals(r.tp_dst, TransportPort.of(80));
         assertEquals(r.any_nw_proto, false);
         assertEquals(r.action, FirewallRule.FirewallAction.ALLOW);
     }
@@ -358,7 +358,7 @@ public class FirewallTest extends FloodlightTestCase {
         rule.nw_proto = IpProtocol.TCP;
         rule.any_nw_proto = false;
         // source is IP 192.168.1.2
-        rule.nw_src_prefix_and_mask = IPv4AddressWithMask.of("192.168.1.2/24");
+        rule.nw_src_prefix_and_mask = IPv4AddressWithMask.of("192.168.1.2/32");
         rule.any_nw_src = false;
         // dest is network 192.168.1.0/24
         rule.nw_dst_prefix_and_mask = IPv4AddressWithMask.of("192.168.1.0/24");
@@ -373,7 +373,7 @@ public class FirewallTest extends FloodlightTestCase {
         verify(sw);
 
         IRoutingDecision decision = IRoutingDecision.rtStore.get(cntx, IRoutingDecision.CONTEXT_DECISION);
-        assertEquals(decision.getRoutingAction(), IRoutingDecision.RoutingAction.FORWARD_OR_FLOOD);
+        assertEquals(IRoutingDecision.RoutingAction.FORWARD_OR_FLOOD, decision.getRoutingAction());
 
         // clear decision
         IRoutingDecision.rtStore.remove(cntx, IRoutingDecision.CONTEXT_DECISION);
@@ -383,7 +383,7 @@ public class FirewallTest extends FloodlightTestCase {
         verify(sw);
 
         decision = IRoutingDecision.rtStore.get(cntx, IRoutingDecision.CONTEXT_DECISION);
-        assertEquals(decision.getRoutingAction(), IRoutingDecision.RoutingAction.DROP);
+        assertEquals(IRoutingDecision.RoutingAction.DROP, decision.getRoutingAction());
     }
 
     @Test
-- 
GitLab