# Robert Singleton
# 2/25/2023
#
# Plots f(x) = a^x mod N vs x for diagnostics.
#
# v1
# - 09-shor1_period_v1.py
# - From 02-period_N33_v1.py
#
# 13-code
# v2
# - python package
#
# v3
# - python package
#
# v4
# truncation
#
# v5
# tries

import math

import matplotlib.pyplot as plt
import numpy as np
from numpy.random import randint

# shor local
from shor_util import *

######### subroutines ######### 

# find factors of N given base a and period r
def find_factors(N, a, r):
    """Return the two candidate factors of N from Shor's post-processing.

    With b = a**(r // 2) the candidates are gcd(b + 1, N) and gcd(b - 1, N).
    No check is made that r is even or that the factors are non-trivial;
    the values are computed and printed for diagnostics.

    Parameters
    ----------
    N : int
        Number to factor.
    a : int
        Base of the modular exponentiation a^x mod N.
    r : int
        Period of a^x mod N, as found by the caller.

    Returns
    -------
    tuple of int
        (gcd(b + 1, N), gcd(b - 1, N)).
    """
    # Exact integer halving (int(r/2) would round-trip through a float).
    b = a ** (r // 2)
    modp = (b + 1) % N
    modm = (b - 1) % N
    # Three-argument pow reduces mod N at each step instead of building b**r.
    modr = pow(b, r, N)
    print("** a, r, b, b+1 modN, b-1 modN, b^r:",
          a, r, b, modp, modm, modr)
    # b is an arbitrary-precision int that can exceed int64 (e.g. 6**30 for
    # N=143, a=6, r=60); math.gcd is exact for any int size.
    f1 = math.gcd(b + 1, N)
    f2 = math.gcd(b - 1, N)
    print("** factor1", f1)
    print("** factor2", f2)
    return f1, f2

##################

# version number used in the output filename
ver = 6

# parameters
# Candidate (N, a) examples with their periods r, kept for reference.
# Only the uncommented N and a at the bottom of this section are used;
# uncomment a pair (and comment out the current one) to switch examples.

#N = 21 # 3*7
#a = 2 # r=6
#a = 4 # r=3
#a = 5 # r=6
#a = 8 # r=2 <==x look at this example, prediction: two peaks
#a = 10 # r=6
#a = 11 # r=6
#a = 13 # r=2
#a = 16 # r=3
#a = 17 # r=6
#a = 19 # r=6
#a = 20 # r=2

#N = 33
#a = 7    # r = 10 => 3*11 - implemented - hard
#a = 10   # r = 2  => 3*11 - implemented in v1 - simple <==x
#a = 2    # r = 10 => trivial root, 1*33
#a = 4    # r = 5 and a is a perfect square
#a = 5    # r = 10 => 3*11
#a = 13   # r = 10 => 3*11
#a = 14   # r = 10 => 3*11
#a = 16   # r = 5 and a is a perfect square
#a = 17   # r = 10 => trivial root, 1*33
#a = 19   # r = 10 => 3*11
#a = 20   # r = 10 => 3*11
#etc.

#N = 35 # 5*7
#a = 4 # r=6 - implemented

#N = 21 # 3*7
#a = 2 # r=6 - implemented

#N = 15 # 3*5 
#a = 2 # 2,4,7,8,11,13 r=2,4 - implemented

#N = 143 # 13*11
#a = 5 # r=20 - implemented
#a = 3 # r=15 no factors
#a = 4 # r=30
#a = 6 # r=60
#a = 7 # did not find solution
#a = 8 # r=20
## output:
##  a = 5
##  ** N, a, L: 143 5 8
##  ** a, r, b, b+1 mod N, b-1 mod N, b^r: 5 20 9765625 13 11 1
##  ** factor1 13
##  ** factor2 11
##  ** [1, 5, 25, 125, 53, 122, 38, 47, 92, 31, 12, 60, 14, 70, 64, 34,
##     27, 135, 103, 86]

#N = 221 # 13*17 L=8
#a = 2   # r=24 M=8

#N = 247 # 13*19 L=8
#a = 4   # r=18 M=8
#a = 2   # r=36 M=8
#a = 150

# Active example: a = 14 = -1 mod 15, so r = 2 and b = a^(r/2) = 14;
# gcd(15, 15) = 15 and gcd(13, 15) = 1 are the trivial factors.
N = 15
a = 14

# Register width ceil(log2 N): bits needed to hold values 0..N-1
# (equal to N's bit length unless N is a power of two).
log2_n = np.log2(N)
L = int(np.ceil(log2_n))

# echo the run parameters
print("** N, a, L:", N, a, L)

# output file prefix, e.g. 14_shor_v6_period_N15_a14
filename = f"14_shor_v{ver}_period_N{N}_a{a}"

# calculate the plotting data
# Number of x points: N by default, truncated for a few (N, a) pairs.
_NX_OVERRIDES = {
    (33, 7): 23,
    (33, 10): 18,
    (143, 4): 50,
    (221, 2): 48,
    (247, 4): 36,
    (247, 2): 55,
}
Nx = _NX_OVERRIDES.get((N, a), N)
# fill arrays
xvals = np.arange(Nx)
# pow(a, x, N) is exact for any x. Computing a**x with numpy int64 x
# overflows int64 and wraps around once a**x > 2**63 - 1 (5**28 already
# does), which would give wrong residues.
yvals = [pow(a, int(x), N) for x in xvals]

# plot a^x mod N vs x and find period r
fig, ax = plt.subplots()
ax.plot(xvals, yvals, linewidth=1, linestyle='dotted', marker='x')
# r is the first x > 0 in the plotted range with a^x mod N == 1.
try:
    r = yvals[1:].index(1) + 1
except ValueError:
    # No period in range: leave r unset (None) so factoring is skipped
    # below instead of failing with a NameError.
    r = None
    print('Could not find period, check a < N with no common factors.')
else:
    # plot r on the graph as a double-headed arrow from x=0 to x=r at y=1
    plt.annotate('', xy=(0, 1), xytext=(r, 1),
                 arrowprops=dict(arrowstyle='<->'))
    plt.annotate('$r=%i$' % r, xy=(r / 3, 1.5))

# factors (only computable once a period has been found)
if r is None:
    title = "N = {0}  a = {1}  (period not found)".format(N, a)
else:
    f1, f2 = find_factors(N, a, r)
    # add factors f1 and f2 to figure title
    title = "N = {0}  a = {1}  $f_1={2}$  $f_2={3}$".format(N, a, f1, f2)

ax.set(xlabel='$x$', ylabel='${1}^x$ mod ${0}$'.format(N, a), title=title)

# print a^x mod N vs x values
print("**", yvals)

# save figure and plot
plt.savefig(filename + ".jpg")
plt.show()


