diff --git a/Multiplicative Algorithm/cSNMF.py b/Multiplicative Algorithm/cSNMF.py index 4318099366f9c60356ab281eb2b5a7bd3a42cc01..d56b90a4fc48d69d22f0e7f6a9cf09d0e07f4240 100644 --- a/Multiplicative Algorithm/cSNMF.py +++ b/Multiplicative Algorithm/cSNMF.py @@ -90,6 +90,7 @@ def sort_WH(W, H): def factorize(data_array, + init_W = None, rank = config.RANK, beta = None, threshold = 0.5, @@ -98,7 +99,8 @@ def factorize(data_array, seed_H = None, log = logger, debug = False, - axing = True): + axing = True, + update_W = True): log.info('Rank= %s, Threshold= %s', rank, threshold) @@ -136,8 +138,12 @@ def factorize(data_array, W_shape = (D.shape[0], rank) H_shape = (rank, D.shape[1]) - np.random.seed(seed_W) - W = np.random.uniform(low = 0.01, high = 1., size = W_shape) + if init_W is None: + np.random.seed(seed_W) + W = np.random.uniform(low = 0.01, high = 1., size = W_shape) + else: + W = init_W + np.random.seed(seed_H) H = np.random.uniform(low = 0.01, high = 1., size = H_shape) log.info('W, H chosen') @@ -158,7 +164,10 @@ def factorize(data_array, if iterations > max_iter: break - W_new = update_W(W, H) + if update_W: + W_new = update_W(W, H) + else: + W_new = W H_new = update_H(W_new, H) # Check for nonnegativity