# Robert Singleton
# 3/16/2023
#
# v2
# - python package
#
# v3
# - python package
#
# v4
# truncation
#
# v5
# tries
#
# v6
# qubit recycle
#
# vv1
# This is version vv1 of qubit recycling code. I am
# keeping the routine run_shor_recycle() local to
# this file. Eventually it will be placed in shor_util.py.
#
# vv1
# does not work.
#
# vv2
# 


# functionality
import numpy as np
import matplotlib.pyplot as plt

from shor_util import c_U1, check_solution, plot_phase_hist

# importing Qiskit
from qiskit import IBMQ, Aer, transpile, assemble
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister

# import basic plot tools and circuits
from qiskit.visualization import plot_histogram
import matplotlib.pyplot as plt

# shor local
#from shor_util import run_shor_recycle

# qubit recycling
def _build_recycle_circuit(N, a, L, M, u_ver, trnc_lv):
    """Build the one-control-qubit (recycled) iterative phase-estimation circuit.

    The single control qubit ``m[0]`` is reused M times. Iteration k
    (0-based) applies controlled-U^(2**(M-1-k)) via ``c_U1`` and measures
    into ``cl[k]``, so ``cl[0]`` receives the least-significant phase bit.
    Before the closing H of iteration k, the phase contributed by the bits
    already measured (cl[0..k-1]) is removed with phase gates that are
    classically conditioned on those bits (semiclassical inverse QFT).
    Everything happens inside ONE circuit, so later iterations see the
    quantum state that is consistent with the earlier measurement outcomes.

    Args:
        N, a: modulus and base forwarded to ``c_U1``.
        L: number of work qubits.
        M: number of control iterations (= number of phase bits).
        u_ver, trnc_lv: forwarded unchanged to ``c_U1``.

    Returns:
        QuantumCircuit with registers m (1 qubit), w (L qubits), cl (M bits).
    """
    n_control = 1
    n_work = L

    m = QuantumRegister(n_control, name='m')
    w = QuantumRegister(n_work, name='w')
    cl = ClassicalRegister(M, name='cl')
    qc = QuantumCircuit(m, w, cl)

    # Work register in state |1>: w0=1, w1=0, ..., w(n_work-1)=0
    qc.x(w[0])

    for k in range(M):
        # recycle the control qubit: back to |0>, then |+>
        qc.reset(m[0])
        qc.h(m[0])

        power = 2**(M - 1 - k)
        qc.append(c_U1(N, a, n_work, power, u_ver, trnc_lv, barrier=False),
                  [m[0]] + list(w))

        # Undo the phase of the lower bits already measured: cl[j] sits at
        # binary position 2**-(k-j+1) of the accumulated phase. This must come
        # BEFORE the H -- a phase gate just before a Z measurement is a no-op.
        for j in range(k):
            qc.p(-2 * np.pi / 2**(k - j + 1), m[0]).c_if(cl[j], 1)

        qc.h(m[0])
        qc.measure(m[0], cl[k])

    return qc


# qubit recycling
def run_shor_recycle(N, a, L, M, filename, u_ver, trnc_lv, max_it=100):
    """Run Shor order finding with a single recycled control qubit.

    Builds the iterative phase-estimation circuit (see
    ``_build_recycle_circuit``), samples it ``max_it`` times on the Aer
    simulator, calls ``check_solution`` on every sampled M-bit string and
    finally plots the histogram of observed strings with ``plot_phase_hist``.

    Args:
        N: integer being factored (forwarded to c_U1 / check_solution).
        a: base whose order mod N is sought.
        L: number of work qubits.
        M: number of control iterations / phase bits (must be >= 1).
        filename: forwarded to check_solution and plot_phase_hist.
        u_ver, trnc_lv: forwarded to c_U1 and plot_phase_hist; per the
            script's parameter comments, U^p construction variant and
            truncation level.
        max_it: number of independent runs (shots); defaults to the
            previously hard-coded 100.

    Raises:
        ValueError: if M < 1 or max_it < 1.
    """
    print("** run_shor_recycle ...")
    print("** trnc_lv:", trnc_lv)
    if M < 1:
        raise ValueError("M must be at least 1")
    if max_it < 1:
        raise ValueError("max_it must be at least 1")

    verbose = False
    qc = _build_recycle_circuit(N, a, L, M, u_ver, trnc_lv)

    aer_sim = Aer.get_backend('aer_simulator')
    t_qc = transpile(qc, aer_sim)
    # memory=True keeps each shot's bitstring so every run can be checked
    # individually, like the former one-run-per-iteration loop.
    results = aer_sim.run(t_qc, shots=max_it, memory=True).result()
    # Bitstrings are ordered cl[M-1] ... cl[0], i.e. most-significant phase
    # bit first -- the same order as the former reversed-cbit l_bin string.
    shot_bits = results.get_memory()

    tally = {}
    for it, l_bin in enumerate(shot_bits):
        print("**xx iteration:", it)
        tally[l_bin] = tally.get(l_bin, 0) + 1

        # check for solution
        count = 1
        sol = check_solution(l_bin, count, N, a, filename, verbose)
        print("**xx l_bin, sol:", l_bin, sol)
        print()

    # only observed outcomes, in ascending binary order
    mycounts = {key: tally[key] for key in sorted(tally)}

    # plot phase histogram
    width = 0.1
    plot_phase_hist(mycounts, N, a, u_ver, trnc_lv, width, filename, verbose)
    plt.show()
    
######### parameters #########

# code version tag, embedded in the output filename
ver = 6

# diagnostic switches
verbose = False
very_quiet = False

# number to factor, base, and the known order of the base
N = 21
a = 2
r = 6  # order of 2 modulo 21

# register sizes
L = (N - 1).bit_length()  # working qubits, equals ceil(log2(N))
M = 5  # control-register iterations; M=5 => power_max=16

# MEO version
# full agreement between versions 0 and 1.
#   u_ver = 0 -> U^p is a strict concatenation of U1's
#   u_ver = 1 -> U^p gates 2-cycles
#   u_ver = 2 -> truncation
u_ver = 1
trnc_lv = 0  # truncation level: 0, 1, ..., r-1; 0=no truncation

# Truncation only applies to u_ver == 2; any other version is tagged "x".
# This value is used in the filename and also passed to run_shor_recycle.
trnc_lv = trnc_lv if u_ver == 2 else "x"

######### start calculation ##########
filename = f"03_shor_v{ver}_N{N}_a{a}_m{M}_uver{u_ver}_{trnc_lv}_recycle"

print("** N, a, L, M, u_ver, trnc_lv:", N, a, L, M, u_ver, trnc_lv)
run_shor_recycle(N, a, L, M, filename, u_ver, trnc_lv)

    
