import numpy as np
import lib.pyplotdefs as pd
import sys

# Number of angles drawn by inverting the CDF in the sampling loop.
trials = 100_000

# `p` weights the sin^2 term and `s` the (1 + cos^2) term of `pdf`/`cdf`;
# `s` takes the remainder so that the two weights always sum to 1.
p = 1.0
s = 1 - p

def pdf(p, s, th):
    """Return the weighted angular density at polar angle `th`.

    The result is the sum of two terms:

    * ``p * 3 / (8 * pi) * sin(th)**2``
    * ``s * 3 / (16 * pi) * (1 + cos(th)**2)``

    Each term is normalised per unit solid angle: multiplying it by
    ``2 * pi * sin(th)`` and integrating over ``[0, pi]`` gives its weight
    (``p`` or ``s``), so the total integrates to ``p + s``.

    Args:
        p: Weight of the ``sin**2`` term.
        s: Weight of the ``1 + cos**2`` term.
        th: Polar angle in radians; a scalar or anything NumPy can broadcast
            over (the expression is evaluated element-wise).

    Returns:
        The density at `th`, with the same shape as `th`.
    """
    return (
        p * 3 / 8 / np.pi * np.sin(th)**2
        + s * 3 / 16 / np.pi * (1 + np.cos(th)**2)
    )

def cdf(p, s, th):
    """Return the cumulative distribution in polar angle at `th`.

    Evaluates the closed form

        0.5 - (2p + s) * 3/8 * cos(th) + (2p - s) / 8 * cos(th)**3

    At ``th = 0`` this equals ``0.5 - (p + s) / 2`` and at ``th = pi`` it
    equals ``0.5 + (p + s) / 2``, so it runs from 0 to 1 only when the
    weights satisfy ``p + s == 1``; the constant ``0.5`` hard-codes that
    normalisation.

    Args:
        p: Weight of the ``sin**2`` component of the density.
        s: Weight of the ``1 + cos**2`` component of the density.
        th: Polar angle in radians; a scalar or anything NumPy can broadcast
            over (the expression is evaluated element-wise).

    Returns:
        The cumulative probability at `th`, with the same shape as `th`.
    """
    return (
        0.5
        - (2 * p + s) * 3 / 8 * np.cos(th)
        + (2 * p - s) / 8 * np.cos(th)**3
    )

# Tabulate both analytic curves on a 1000-point grid spanning [0, pi]; these
# arrays are reused by every plot in this script.
xplot = np.linspace(0.0, np.pi, num=1000)
pdfplot, cdfplot = pdf(p, s, xplot), cdf(p, s, xplot)

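# Disabled quick look at the analytic curves: uncomment to plot the PDF and
# CDF on the grid above and exit before any sampling is done.
# (pd.Plotter()
#     .plot(xplot, pdfplot)
#     .plot(xplot, cdfplot)
#     .ggrid()
#     .show()
#     .close()
# )
# sys.exit(0)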

def newton_raphson(p: float, s: float, r: float) -> float:
    """Solve ``cdf(p, s, th) == r`` for `th` by Newton-Raphson iteration.

    Each step applies ``th -= (cdf(p, s, th) - r) / slope``, where the slope
    of the CDF is taken as ``pdf(p, s, th) * 2 * pi * sin(th)``. The loop
    stops as soon as a step is smaller than 1e-6 in magnitude and returns the
    already-updated angle.

    Args:
        p: Weight forwarded unchanged to `pdf` and `cdf`.
        s: Weight forwarded unchanged to `pdf` and `cdf`.
        r: Target CDF value to invert.

    Returns:
        The angle `th` reached when the step size first drops below 1e-6.
        The iterate is not clamped, so nothing here forces it into [0, pi].

    Raises:
        Exception: If no step smaller than 1e-6 occurs within 1000
            iterations. Before raising, a diagnostic plot of the
            module-level `pdfplot`/`cdfplot` curves with the final
            ``(th, r)`` point marked in red is shown.
    """
    # Every solve starts from the midpoint of [0, pi], independent of `r`.
    th = np.pi / 2
    for _ in range(1000):
        # The slope is zero wherever sin(th) == 0 (th = 0 or pi), so an
        # iterate landing exactly there makes this division blow up.
        dth = (cdf(p, s, th) - r) / (pdf(p, s, th) * 2 * np.pi * np.sin(th))
        th -= dth
        if abs(dth) < 1e-6:
            return th
        # Disabled clamp that would keep the iterate strictly inside (0, pi):
        # th = min(max(th, 1e-6), np.pi * (1 - 1e-6))

    # No convergence: show where the iteration ended up, then fail loudly.
    (pd.Plotter()
        .plot(xplot, pdfplot)
        .plot(xplot, cdfplot)
        .plot(th, r, marker="o", linestyle="", color="r")
        .show()
        .close()
    )
    raise Exception(f"didn't converge: {p=:.3f}; {s=:.3f}; {r=:.3f}")

# Inverse-transform sampling: draw a uniform variate, invert the CDF for it,
# and store the (theta, r) pair in row k of `data`.
data = np.zeros(shape=(trials, 2), dtype=np.float64)
for k in range(trials):
    # Progress counter rewritten in place via the leading carriage return.
    print(f"\r  {k = :5.0f} ", end="", flush=True)
    r = np.random.random()
    th = newton_raphson(p, s, r)
    data[k, 0] = th
    data[k, 1] = r
# Terminate the progress line.
print()

# Overlay the sampled (theta, r) pairs (translucent C0 markers) on the
# analytic CDF (solid black line), then show and close the figure.
_overlay = pd.Plotter()
_overlay = _overlay.plot(
    xplot, cdfplot,
    marker="", linestyle="-", color="k",
)
_overlay = _overlay.plot(
    data[:, 0], data[:, 1],
    marker="o", linestyle="", color="C0", alpha=0.35,
)
_overlay = _overlay.ggrid()
_overlay = _overlay.show()
_overlay.close()