From a530c4f442073624b98549a9f322c44a13aa96b7 Mon Sep 17 00:00:00 2001
From: Sepehr Madani <ssepehrmadani@gmail.com>
Date: Mon, 14 Sep 2020 22:46:15 -0400
Subject: [PATCH] Resolve compatibility issues with other algorithms

---
 algorithms/butterfly_algorithm.py | 49 ++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/algorithms/butterfly_algorithm.py b/algorithms/butterfly_algorithm.py
index b767371..0ae95ee 100644
--- a/algorithms/butterfly_algorithm.py
+++ b/algorithms/butterfly_algorithm.py
@@ -25,27 +25,36 @@ class ButterflyAlgorithm(BaseAlgorithm):
         super().check_parameters()
         assert len(self.null_degrees) == 1
     
+    def get_weights(self):
+        weights = [exp(1j * x) for x in self.vector_changes]
+        return weights
+    
+    def get_final_weights(self):
+        weights = [exp(1j * (x+self.alpha/2)) for x in self.vector_changes]
+        return weights
+
     def update_pattern(self):
         pattern_values = compute_pattern(
             N=self.N,
             k=self.k,
-            weights=[exp(1j * -x) for x in self.vector_changes],
+            weights=self.get_weights(),
             degrees=self.null_degrees,
             use_absolute_value=False
         )
-        self.pattern = pattern_values[0].conjugate()
+        self.pattern = pattern_values[0]
 
     def solve(self):
-        alpha = (2*pi) / (2**self.bit_resolution)
+        self.alpha = (2*pi) / (2**self.bit_resolution)
 
         self.null_deg = self.null_degrees[0]
         self.theta = pi * cos(radians(self.null_deg))
 
-        self.vector_dirs = [wrapToPi(k * self.theta) for k in range(self.N)]
+        self.vector_dirs = [wrapToPi(-k * self.theta) for k in range(self.N)]
         self.sum_dir = wrapToPi(phase(sum([exp(1j* x) for x in self.vector_dirs])))
 
         self.vector_changes = [0.0] * self.N
-        self.vector_change_limit = (pi * (2**self.bit_count - 1) / 2**self.bit_resolution)
+        self.vector_change_limit_pos = self.alpha * (2**self.bit_count-2) / 2
+        self.vector_change_limit_neg = -self.alpha * (2**self.bit_count) / 2
         
         self.update_pattern()
 
@@ -58,12 +67,12 @@ class ButterflyAlgorithm(BaseAlgorithm):
                 other = self.N - 1 - idx # the other vector, symmteric to vector[idx] wrt sum
                 angle_with_sum = wrapToPi(self.vector_dirs[idx] + self.vector_changes[idx] - self.sum_dir)
                 
-                if 0 <= angle_with_sum < pi - alpha/2: # the vector is after the sum direction (left half)
-                    self.vector_changes[idx] += alpha
-                    self.vector_changes[other] -= alpha
-                elif -pi + alpha/2 < angle_with_sum <= 0: # the vector is before the sum direction (right half)
-                    self.vector_changes[idx] -= alpha
-                    self.vector_changes[other] += alpha
+                if 0 <= angle_with_sum < pi - self.alpha/2: # the vector is after the sum direction (left half)
+                    self.vector_changes[idx] += self.alpha
+                    self.vector_changes[other] -= self.alpha
+                elif -pi + self.alpha/2 < angle_with_sum <= 0: # the vector is before the sum direction (right half)
+                    self.vector_changes[idx] -= self.alpha
+                    self.vector_changes[other] += self.alpha
 
                 self.normalize_change_vector(idx)
                 self.normalize_change_vector(other)
@@ -78,15 +87,13 @@ class ButterflyAlgorithm(BaseAlgorithm):
                         self.vector_changes = original_vector_changes[:]
                         self.update_pattern()
 
-        # print(f'\nFinal pattern value: {abs(self.pattern) = }'
-            #   f'\nFinal score: {-20 * log10(abs(self.pattern))}')
-        return -20 * log10(abs(self.pattern))
+        return (
+            self.get_final_weights(),
+            -20 * log10(abs(self.pattern))
+        )
 
     def normalize_change_vector(self, idx):
-        limit = self.vector_change_limit
-        
-        if self.vector_changes[idx] > 0:
-            self.vector_changes[idx] = min(limit, self.vector_changes[idx])
-
-        if self.vector_changes[idx] < 0:
-            self.vector_changes[idx] = max(-limit, self.vector_changes[idx])
+        limit_pos = self.vector_change_limit_pos
+        limit_neg = self.vector_change_limit_neg
+        self.vector_changes[idx] = min(limit_pos, self.vector_changes[idx])
+        self.vector_changes[idx] = max(limit_neg, self.vector_changes[idx])
-- 
GitLab