{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run cSNMF on Trips data using multiplicative update rules\n",
    "\n",
    "**Constraint:** the $L_1$ norm of each column of $W$ should be 1.\n",
    "\n",
    "Download the data matrices to your local machine from the following link, and place `D_trips.txt` in this notebook's working directory (it is read with a relative path below):\n",
    " - https://uofi.box.com/s/yo60oe084d68obohgraek4weuaqtpgsp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the star import presumably supplies `logger` (used below);\n",
    "# prefer an explicit import of that name once confirmed.\n",
    "from __init__ import *\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import config\n",
    "import cSNMF\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Read Full-Link data and prep for running NMF.\n",
    "D = np.loadtxt('D_trips.txt')\n",
    "logger.info('Full_link data has been read')\n",
    "\n",
    "# config.SEEDED == 1 -> pass fixed seeds (0 for W, 1 for H) to cSNMF.factorize;\n",
    "# config.SEEDED == 0 -> pass None for both seeds.\n",
    "if config.SEEDED == 1:\n",
    "    seed_W = 0; seed_H = 1\n",
    "elif config.SEEDED == 0:\n",
    "    seed_W = None; seed_H = None\n",
    "else:\n",
    "    logger.critical('Seed value invalid. Needs to be 0 or 1. Check config.py!')\n",
    "    # Raise rather than call quit(): quit() does not raise inside a Jupyter kernel,\n",
    "    # so it cannot reliably stop Run All before the factorize cell uses the seeds.\n",
    "    raise ValueError('Seed value invalid. Needs to be 0 or 1. Check config.py!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for cSNMF.factorize (the rank is taken from config.RANK).\n",
    "BETA = 5000\n",
    "MAX_ITER = 600\n",
    "\n",
    "W, H, results = cSNMF.factorize(D,\n",
    "                                beta=BETA,\n",
    "                                rank=config.RANK,\n",
    "                                max_iter=MAX_ITER,\n",
    "                                seed_W=seed_W,\n",
    "                                seed_H=seed_H,\n",
    "                                debug=True,\n",
    "                                axing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Third value returned by cSNMF.factorize; its contents are defined by the cSNMF module.\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set SAVE_FACTORS = True to write W and H to text files tagged with the seeds actually used.\n",
    "SAVE_FACTORS = False\n",
    "\n",
    "if SAVE_FACTORS:\n",
    "    tag = f'(seed_W = {seed_W},seed_H = {seed_H})'\n",
    "    np.savetxt(f'W_{tag}.txt', W)\n",
    "    np.savetxt(f'H_{tag}.txt', H)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}