##### shor utilities #####

# functionality
import contfrac
import numpy as np

# importing Qiskit
from qiskit import IBMQ, Aer, transpile, assemble
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister

# import basic plot tools and circuits
from qiskit.visualization import plot_histogram
from qiskit.circuit.library import QFT
import matplotlib.pyplot as plt

# import ME operators
from shor_meo_N15_ax import *
from shor_meo_N21_a2 import *
#from shor_meo_N21_a13 import *
from shor_meo_N33_a7 import *
from shor_meo_N35_a4 import *
from shor_meo_N143_a5 import *
from shor_meo_N247_a2 import *
#from shor_meo_N247_a4 import *

# ME operator plumbing
#
# controlled modular exponentiation operator U_a^power
# barriers included for circuit plotting
def c_U1(N, a, L, power, u_ver, trnc_lv, barrier=True):
    """Return the singly-controlled modular-exponentiation gate for ``power``.

    A fresh ``L``-qubit circuit is populated by ``UME`` for the given
    ``N``, ``a``, ``power``, ``u_ver`` and ``trnc_lv`` (``barrier`` is
    forwarded unchanged), converted to a gate, given a display name, and
    wrapped with one control qubit via ``control()``.
    """
    work_circuit = QuantumCircuit(L)
    UME(work_circuit, N, a, power, u_ver, trnc_lv, barrier)
    gate = work_circuit.to_gate()
    # display name: two newlines, two spaces, then "U^<power>"
    gate.name = f"\n\n  U^{power}"
    return gate.control()

# interface with N and a
# U2, U4, U8, ...
def U512(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-512 ME operator for ``(N, a)`` to circuit ``U``.

    Dispatches on (143, 5) and (247, 2); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 143:
        if a == 5:
            U512_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U512_N247_a2(U, u_ver, trnc_lv, barrier)

def U256(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-256 ME operator for ``(N, a)`` to circuit ``U``.

    Dispatches on (143, 5) and (247, 2); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 143:
        if a == 5:
            U256_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U256_N247_a2(U, u_ver, trnc_lv, barrier)

def U128(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-128 ME operator for ``(N, a)`` to circuit ``U``.

    Dispatches on (143, 5), (247, 2) and (247, 4); for any other pair
    nothing is applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 143:
        if a == 5:
            U128_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U128_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U128_N247_a4(U, u_ver, trnc_lv, barrier)

def U64(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-64 ME operator for ``(N, a)`` to circuit ``U``.

    Dispatches on (143, 5), (247, 2) and (247, 4); for any other pair
    nothing is applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 143:
        if a == 5:
            U64_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U64_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U64_N247_a4(U, u_ver, trnc_lv, barrier)

def U32(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-32 ME operator for ``(N, a)`` to circuit ``U``.

    Dispatches on (21, 2), (21, 13), (33, 7), (143, 5), (247, 2) and
    (247, 4); for any other pair nothing is applied and ``U`` is left as
    it was. Returns ``None``.
    """
    if N == 21:
        if a == 2:
            U32_N21_a2(U, u_ver, trnc_lv, barrier)
        elif a == 13:
            U32_N21_a13(U, u_ver, trnc_lv, barrier)
    elif N == 33:
        if a == 7:
            U32_N33_a7(U, u_ver, trnc_lv, barrier)
    elif N == 143:
        if a == 5:
            U32_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U32_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U32_N247_a4(U, u_ver, trnc_lv, barrier)
 
def U16(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-16 ME operator for ``(N, a)`` to circuit ``U``.

    For N = 15 the base ``a`` is forwarded to the N = 15 operator.
    Otherwise dispatches on (21, 2), (21, 13), (33, 7), (35, 4),
    (143, 5), (247, 2) and (247, 4); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 15:
        U16_N15(U, a, u_ver, trnc_lv, barrier)
    elif N == 21:
        if a == 2:
            U16_N21_a2(U, u_ver, trnc_lv, barrier)
        elif a == 13:
            U16_N21_a13(U, u_ver, trnc_lv, barrier)
    elif N == 33:
        if a == 7:
            U16_N33_a7(U, u_ver, trnc_lv, barrier)
    elif N == 35:
        if a == 4:
            U16_N35_a4(U, u_ver, trnc_lv, barrier)
    elif N == 143:
        if a == 5:
            U16_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U16_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U16_N247_a4(U, u_ver, trnc_lv, barrier)

def U8(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-8 ME operator for ``(N, a)`` to circuit ``U``.

    For N = 15 the base ``a`` is forwarded to the N = 15 operator.
    Otherwise dispatches on (21, 2), (21, 13), (33, 7), (35, 4),
    (143, 5), (247, 2) and (247, 4); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 15:
        U8_N15(U, a, u_ver, trnc_lv, barrier)
    elif N == 21:
        if a == 2:
            U8_N21_a2(U, u_ver, trnc_lv, barrier)
        elif a == 13:
            U8_N21_a13(U, u_ver, trnc_lv, barrier)
    elif N == 33:
        if a == 7:
            U8_N33_a7(U, u_ver, trnc_lv, barrier)
    elif N == 35:
        if a == 4:
            U8_N35_a4(U, u_ver, trnc_lv, barrier)
    elif N == 143:
        if a == 5:
            U8_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U8_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U8_N247_a4(U, u_ver, trnc_lv, barrier)

def U4(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-4 ME operator for ``(N, a)`` to circuit ``U``.

    For N = 15 the base ``a`` is forwarded to the N = 15 operator.
    Otherwise dispatches on (21, 2), (21, 13), (33, 7), (35, 4),
    (143, 5), (247, 2) and (247, 4); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 15:
        U4_N15(U, a, u_ver, trnc_lv, barrier)
    elif N == 21:
        if a == 2:
            U4_N21_a2(U, u_ver, trnc_lv, barrier)
        elif a == 13:
            U4_N21_a13(U, u_ver, trnc_lv, barrier)
    elif N == 33:
        if a == 7:
            U4_N33_a7(U, u_ver, trnc_lv, barrier)
    elif N == 35:
        if a == 4:
            U4_N35_a4(U, u_ver, trnc_lv, barrier)
    elif N == 143:
        if a == 5:
            U4_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U4_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U4_N247_a4(U, u_ver, trnc_lv, barrier)
     
def U2(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-2 ME operator for ``(N, a)`` to circuit ``U``.

    For N = 15 the base ``a`` is forwarded to the N = 15 operator.
    Otherwise dispatches on (21, 2), (21, 13), (33, 7), (35, 4),
    (143, 5), (247, 2) and (247, 4); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 15:
        U2_N15(U, a, u_ver, trnc_lv, barrier)
    elif N == 21:
        if a == 2:
            U2_N21_a2(U, u_ver, trnc_lv, barrier)
        elif a == 13:
            U2_N21_a13(U, u_ver, trnc_lv, barrier)
    elif N == 33:
        if a == 7:
            U2_N33_a7(U, u_ver, trnc_lv, barrier)
    elif N == 35:
        if a == 4:
            U2_N35_a4(U, u_ver, trnc_lv, barrier)
    elif N == 143:
        if a == 5:
            U2_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U2_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U2_N247_a4(U, u_ver, trnc_lv, barrier)

def U1(U, N, a, u_ver, trnc_lv, barrier=False):
    """Apply the power-1 ME operator for ``(N, a)`` to circuit ``U``.

    For N = 15 the base ``a`` is forwarded to the N = 15 operator.
    Otherwise dispatches on (21, 2), (21, 13), (33, 7), (35, 4),
    (143, 5), (247, 2) and (247, 4); for any other pair nothing is
    applied and ``U`` is left as it was. Returns ``None``.
    """
    if N == 15:
        U1_N15(U, a, u_ver, trnc_lv, barrier)
    elif N == 21:
        if a == 2:
            U1_N21_a2(U, u_ver, trnc_lv, barrier)
        elif a == 13:
            U1_N21_a13(U, u_ver, trnc_lv, barrier)
    elif N == 33:
        if a == 7:
            U1_N33_a7(U, u_ver, trnc_lv, barrier)
    elif N == 35:
        if a == 4:
            U1_N35_a4(U, u_ver, trnc_lv, barrier)
    elif N == 143:
        if a == 5:
            U1_N143_a5(U, u_ver, trnc_lv, barrier)
    elif N == 247:
        if a == 2:
            U1_N247_a2(U, u_ver, trnc_lv, barrier)
        elif a == 4:
            U1_N247_a4(U, u_ver, trnc_lv, barrier)

# master ME operator
def UME(U, N, a, power, u_ver, trnc_lv, barrier=False):
    """Apply the ME operator for ``power`` to circuit ``U`` and return ``U``.

    ``power`` is matched against 1, 2, 4, ..., 512 and the corresponding
    ``U<power>`` dispatcher is called with ``(U, N, a, u_ver, trnc_lv,
    barrier)``. Any other ``power`` applies nothing; ``U`` is returned
    either way.
    """
    call_args = (U, N, a, u_ver, trnc_lv, barrier)
    # lambdas defer the name lookup, so only the selected operator is resolved
    operators = (
        (1, lambda: U1(*call_args)),
        (2, lambda: U2(*call_args)),
        (4, lambda: U4(*call_args)),
        (8, lambda: U8(*call_args)),
        (16, lambda: U16(*call_args)),
        (32, lambda: U32(*call_args)),
        (64, lambda: U64(*call_args)),
        (128, lambda: U128(*call_args)),
        (256, lambda: U256(*call_args)),
        (512, lambda: U512(*call_args)),
    )
    for exponent, apply_operator in operators:
        if power == exponent:
            apply_operator()
            break
    return U

# classical definition of U(a^power * x)
def U_cl_orig(x, a, power, N, L):
    """Classically evaluate ``a**power * x mod N`` by repeated multiplication.

    Starting from ``x``, multiplies by ``a`` and reduces modulo ``N``
    ``power`` times. Returns ``(value, bits)`` where ``bits`` is ``value``
    in binary, left-padded with zeros to at least ``L`` characters.
    """
    value = x
    for _ in range(power):
        value = int(np.mod(a * value, N))
    bits = bin(value)[2:].zfill(L)
    return value, bits

# classical definition of U(a^power * x)
def U_cl(x, a, power, N, L):
    """Classically evaluate ``a**power * x mod N``.

    Uses the built-in three-argument ``pow`` so the (potentially
    hundreds-of-digits) integer ``a**power`` is never materialised and the
    result does not depend on how NumPy handles oversized Python ints.

    Parameters are converted with ``int()``; they are expected to be
    integers with ``N`` positive and ``power`` non-negative.

    Returns ``(amod, amod_bin)`` where ``amod_bin`` is ``amod`` in binary,
    left-padded with zeros to at least ``L`` characters.
    """
    modulus = int(N)
    amod = (pow(int(a), int(power), modulus) * int(x)) % modulus
    amod_bin = bin(amod)[2:].zfill(L)
    return amod, amod_bin

# draw ME operator with figure captions
def plot_U(N, a, L, power, u_ver, trnc_lv, filename):
    """Draw the ME operator for ``power`` with a caption and save it as JPEG.

    Builds an ``L``-qubit circuit on a register named ``w``, fills it via
    ``UME`` with barriers enabled, draws it on a matplotlib axes together
    with a caption listing ``N``, ``a``, ``power`` and ``u_ver`` (plus
    ``trnc_lv`` unless it equals ``"x"``), and saves the figure to
    ``filename + "_U<power>.jpg"``.

    The figure is left open so that a later ``plt.show()`` by the caller
    displays it. Returns the circuit.
    """
    w = QuantumRegister(L, name='w')
    U = QuantumCircuit(w)
    UME(U, N, a, power, u_ver, trnc_lv, barrier=True)

    large_N = N > 100
    fig = plt.figure(figsize=(20, 10))
    if large_N:
        # wide, flat canvas for the larger circuits
        fig.set_figwidth(20)
        fig.set_figheight(5)
    ax = fig.add_subplot()

    # number of circuit columns drawn before the drawing wraps
    fold = 70 if large_N else 65

    if trnc_lv == "x":
        ax.text(0.2, 0.8, 'N={0} a={1} power={2} u_ver={3}'.format(
            N, a, power, u_ver), size=14)
    else:
        ax.text(0.2, 0.8, 'N={0} a={1} power={2} u_ver={3} trnc_lv={4}'.format(
            N, a, power, u_ver, trnc_lv), size=14)

    # request the matplotlib drawer explicitly: ``ax`` is only honoured by
    # that drawer, and the default drawer depends on the user's Qiskit config
    U.draw(output="mpl", ax=ax, fold=fold, plot_barriers=True)
    plt.savefig(filename+"_U{0}.jpg".format(power), bbox_inches='tight')
    return U

# bin_str="0111" : w3=0, w2=1, w1=1, w0=1
def init(U, bin_st, L):
    """Prepare the first ``L`` qubits of ``U`` from the bit string ``bin_st``.

    The string is read from its right-hand end: qubit ``q`` takes the
    character ``q`` places from the end (``"0111"`` means w3=0, w2=1, w1=1,
    w0=1). A ``"0"`` applies ``U.id`` and a ``"1"`` applies ``U.x``; any
    other character applies nothing. Returns ``U``.
    """
    for qubit in range(L):
        # negative index walks the string from its last character backwards
        bit = bin_st[-1 - qubit]
        if bit == "0":
            U.id(qubit)
        elif bit == "1":
            U.x(qubit)
    return U

# convert list to string
def listToString(s):
    """Concatenate the strings in ``s`` with no separator."""
    return "".join(s)

def find_gate(x_targ, x_gate, stack):
    """Print Qiskit statements for a multi-controlled X gate taking x_gate to x_targ.

    Parameters
    ----------
    x_targ : str
        Bit string the circuit should have produced.
    x_gate : str or int
        Bit string the circuit actually produced; the integer 0 means
        "nothing to fix" and makes the function return immediately.
    stack : list of str
        Bit strings already produced correctly. Entries are removed from
        this list in place as control bits are chosen, so the caller's list
        is modified (it is empty when the function finishes).

    The positions where ``x_targ`` and ``x_gate`` differ become target bits.
    Control bits are then chosen among positions where some ``stack`` entry
    differs from ``x_gate`` (presumably so the gate does not act on those
    already-correct states). String index ``n`` is printed as qubit
    ``len(x_gate) - n - 1``. Everything is written to stdout; nothing is
    returned.

    NOTE(review): an entry of ``stack`` that is identical to ``x_gate`` is
    never removed, so the ``while`` loop below would not terminate for it.
    """

    # return if x_gate is zero
    if x_gate == 0:
        return

    # create target bit register
    t_bits = []
    for n in range(len(x_gate)):
        if x_targ[n] != x_gate[n]:
            t_bits.append(n)
    print()
    print("** t_bits", t_bits)

    # create stencil with target bits
    # stencil symbols: "x" target bit, "-" not yet examined,
    # "." chosen as control bit, "+" examined but not needed as control
    stencil = ["-" for _ in range(len(x_targ))]
    for n in t_bits:
        stencil[n] = "x"
    stencil_str = listToString(stencil)
        
    # print stencil with target and gate bits
    print("** stencil ", stencil_str)
    print("** x_gate  ", x_gate,"<")    
    print("** x_target", x_targ)
    print()

    # print gate bit with stack
    print("** stencil ", stencil_str)    
    print("** x_gate  ", x_gate,"<")
    n = 0
    for stack0 in stack:
        n +=1
        print("** stack {:02d}".format(n), stack0)        
    print()
    
    # walk the non-target positions from the last string index down to the
    # first, adding a control bit wherever a stack entry differs from x_gate
    c_bits = []
    stencil_exhausted = False    
    for n in range(len(stencil))[::-1]:
        print("iteration number", len(stencil) - n )
        stack_remove = []
        stencil_change = False
        if stencil[n] == "-":
            for stack0 in stack:
                if stack0[n] != x_gate[n]:
                    stack_remove.append(stack0)
                    c_bits.append(n)
                    stencil_change = True
            if stencil_change:
                stencil[n] = "."
            else:
                stencil[n] = "+"
            stencil_str = listToString(stencil)
            # entries distinguished by this control bit need no further handling
            for st in stack_remove:
                stack.remove(st)
                                
            # drop duplicate control indices (order is not preserved by set)
            c_bits = list(set(c_bits))            
            print("** pruned stack")
            print("** c_bits", c_bits)
            print("** t_bits", t_bits)                    
            print(stencil_str)    
            print(x_gate,"<")    
            for stack0 in stack:
                print(stack0)
            print()

    # entries still on the stack agree with x_gate on every non-target
    # position; keep going until the stack is empty
    while stack != []:
        print("** stack", stack, len(stack))
        print()
        print("** c_bits", c_bits)        
        print("** t_bits", t_bits)        
        print(stencil_str)    
        print(x_gate,"<")    
        for stack0 in stack:
            print(stack0)
        print()

        # walk the target positions from the first string index upwards:
        # convert target bits into control bits
        for n in range(len(stencil)):
            print("iteration number", len(stencil) - n )
            stack_remove = []
            stencil_change = False
            if stencil[n] == "x":
                for stack0 in stack:
                    if stack0[n] != x_gate[n]:
                        stack_remove.append(stack0)
                        stencil_change = True
                if stencil_change:
                    # this position stops being a target and becomes a control
                    stencil[n] = "."                    
                    c_bits.append(n)
                    t_bits.remove(n)
                stencil_str = listToString(stencil)
                c_bits = list(set(c_bits))                            
                # prune stack
                for st in stack_remove:
                    stack.remove(st)

                # print state
                print("** pruned stack")
                print("** c_bits", c_bits)
                print("** t_bits", t_bits)                        
                print(stencil_str)    
                print(x_gate,"<")    
                for stack0 in stack:
                    print(stack0)
                print()
            
    # output circuit
    print("** writing output circuit")
    print("** paste into meo file, ")
    print("** e.g. into U256_N247_a2() in shor_meo_N247_a2.py")
    print()
    
    # initial U.x(c) when x_gate[c]=0
    # (flips the control so it is 1 when the gate is applied)
    for c in c_bits:
        if x_gate[c] == '0':
            rev = len(x_gate) - int(c) - 1
            print("U.x(", rev, ")")
    
    # place c_bits in qiskit order
    # (string index n corresponds to qubit len(x_gate) - n - 1)
    c_bits_qiskit = []
    for c in c_bits:
        rev = len(x_gate) - int(c) - 1
        c_bits_qiskit.append(rev)
    c_bits_qiskit = sorted(c_bits_qiskit)

    # place t_bits in qiskit order   
    t_bits_qiskit = []
    for t in t_bits:
        rev = len(x_gate) - int(t) - 1
        t_bits_qiskit.append(rev)
    t_bits_qiskit = sorted(t_bits_qiskit)

    # multi-control x-gates
    # one gate per remaining target bit: cx for one control, ccx for two,
    # mct otherwise (which includes the case of no control bits at all)
    for t in t_bits_qiskit:
        if len(c_bits) == 1:
            print("U.cx(",c_bits_qiskit[0],",", t, ")")
        elif len(c_bits) == 2:
            print("U.ccx(",c_bits_qiskit[0],",", c_bits_qiskit[1],",", t, ")")
        else:
            print("U.mct(",c_bits_qiskit,",", t, ")")
        
    # final U.x(c) when x_gate[c]=0
    # (repeats the initial flips, undoing them)
    for c in c_bits:
        if x_gate[c] == '0':
            rev = len(x_gate) - int(c) - 1
            print("U.x(", rev, ")")

    # print barrier
    print("if barrier: U.barrier()")

# run quantum cycle
def run_quantum_cycle(x_bin, a, L, N, power, u_ver, trnc_lv):
    """Run the ME operator once on basis state ``x_bin`` and return the result.

    An ``L``-qubit register ``w`` is initialised from ``x_bin`` via
    ``init``, the operator selected by ``UME`` is applied, and all qubits
    are measured into an ``L``-bit classical register. The circuit is
    executed for a single shot on the ``aer_simulator`` backend and the
    most frequent (here: only) measured bit string is returned.
    """
    work = QuantumRegister(L, name='w')
    readout = ClassicalRegister(L, name='cl')
    circuit = QuantumCircuit(work, readout)

    # prepare the input state, then apply the operator under test
    init(circuit, x_bin, L)
    circuit.barrier()
    UME(circuit, N, a, power, u_ver, trnc_lv)
    circuit.barrier()
    circuit.measure(work, readout)

    # single-shot simulation
    backend = Aer.get_backend('aer_simulator')
    job_input = assemble(transpile(circuit, backend))
    outcome = backend.run(job_input, shots=1).result()
    return outcome.get_counts().most_frequent()
    
# MEO construction
def run_meo_sequence(N, a, L, M, power, x_list, filename, u_ver, trnc_lv,
                     check_quantum, verbose=False, very_quintet=False):
    """Compare the ME operator against the classical map over ``x_list``.

    For each ``x`` in ``x_list`` the classical value ``a**power * x mod N``
    (``U_cl``) is compared with the circuit output (``run_quantum_cycle``)
    when ``check_quantum`` is true; otherwise the classical value stands in
    for the quantum one, which just lists the sequence. Matching outputs
    are collected; at the first mismatch the loop stops and, when
    ``check_quantum`` is true, ``find_gate`` prints a gate that corrects it,
    followed by a fresh listing of the sequence up to the first mismatch.

    Parameters
    ----------
    N, a, L, power, u_ver, trnc_lv :
        Forwarded to ``U_cl``, ``run_quantum_cycle`` and ``plot_U``.
    M :
        Not used by this function.
    x_list : iterable of int
        Inputs to test, in order.
    filename : str
        Prefix handed to ``plot_U`` for the saved circuit image.
    check_quantum : bool
        Run the circuit and construct the next gate when true.
    verbose : bool
        Also plot a histogram of each single-shot result.
    very_quintet : bool
        When true, skip plotting the operator at the end.
    """
    # identify subroutine
    print("** run_meo_sequence ...")
    print()
    stack = []
    # integer 0 marks "no mismatch found"; find_gate returns early on it
    x_targ = 0
    x_gate = 0
    for x in x_list:

        # calculate UME classically
        x_bin = bin(x)[2:].zfill(L)
        amod, amod_bin = U_cl(x, a, power, N, L) # a^power * x % N

        if check_quantum:
            amod_q_bin = run_quantum_cycle(x_bin, a, L, N, power, u_ver, trnc_lv)
        else:
            amod_q_bin = amod_bin

        if amod_bin != amod_q_bin:
            print("** cl =/= qm:", amod_bin, amod_q_bin)
        if verbose:
            # run_quantum_cycle only returns the measured string of its
            # single shot, so rebuild the one-entry counts for the histogram
            plot_histogram({amod_q_bin: 1},
                           title='N = {0} a = {1}, power={2} u_ver={3}'.format(
                        N, a, power, u_ver))

        # output
        print("in:", "{:03d}".format(x), x_bin,
              "out:", "{:03d}".format(amod), amod_bin, amod_q_bin,
              "{:03d}".format(int(amod_q_bin,2)))

        # append to stack
        if amod_bin == amod_q_bin:
            stack.append(amod_q_bin)
        else:
            x_targ = amod_bin # correct value
            x_gate = amod_q_bin   # incorrect value
            break

    # construct next quantum gate
    if check_quantum:
        find_gate(x_targ, x_gate, stack)

        print()
        print("** sequence to date")
        for x in x_list:
            x_bin = bin(x)[2:].zfill(L)
            amod, amod_bin = U_cl(x, a, power, N, L)
            amod_q_bin = run_quantum_cycle(x_bin, a, L, N, power, u_ver, trnc_lv)

            if amod_bin != amod_q_bin:
                print("** cl =/= qm:", amod_bin, amod_q_bin)
            # output
            print("in:", "{:03d}".format(x), x_bin,
                  "out:", "{:03d}".format(amod), amod_bin, amod_q_bin,
                  "{:03d}".format(int(amod_q_bin,2)))
            if amod_bin != amod_q_bin:
                # stop the listing at the first input the circuit gets wrong
                break

    # Plot UME circuits
    if not very_quintet:
        if (u_ver == 0 and power <= 4) or u_ver != 0:
            print("** plot ME operator for power=", power, "...")
            plot_U(N, a, L, power, u_ver, trnc_lv, filename)
            plt.show()

# inverse QFT
def qft_dagger(n):
    """Return an ``n``-qubit circuit named "QFT†" (inverse QFT).

    The circuit first swaps qubit ``k`` with qubit ``n-1-k`` for the lower
    half of the register, then for each qubit ``j`` applies controlled-phase
    gates of angle ``-pi / 2**(j-m)`` from every lower qubit ``m`` followed
    by a Hadamard on ``j``.
    """
    circuit = QuantumCircuit(n)
    # reverse the qubit order
    for low in range(n // 2):
        circuit.swap(low, n - 1 - low)
    for target in range(n):
        for source in range(target):
            angle = -np.pi / float(2 ** (target - source))
            circuit.cp(angle, source, target)
        circuit.h(target)
    circuit.name = "QFT†"
    return circuit

# returns True if the measured bit string l_bin yields factors
def check_solution(l_bin, count, N, a, filename, verbose):
    """Return True if the measured bit string ``l_bin`` yields factors of ``N``.

    ``l_bin`` is read as the binary fraction ``0.l_bin``, i.e. the phase
    ``l / 2**len(l_bin)``. The phase is built as an exact integer fraction
    (no detour through the decimal text of a float, which fails once the
    float prints in scientific notation), its continued-fraction
    convergents are computed with ``contfrac``, and every convergent
    denominator ``r`` is tested as a period of ``a`` modulo ``N``:
    ``r`` must be even, ``a**r`` must be 1 mod ``N``, and ``a**(r/2)`` must
    not be congruent to +1 or -1 mod ``N``.

    Parameters
    ----------
    l_bin : str
        Measured control-register bit string.
    count :
        Only echoed in the verbose output.
    N, a : int
        Number to factor and base of the modular exponentiation.
    filename :
        Not used by this function.
    verbose : bool
        Print the phase, its convergents and the per-convergent verdict.

    Returns
    -------
    bool
        True if at least one convergent passes the tests.
    """
    from math import gcd

    # not a solution is the default
    sol = False

    # convert binary to dec
    l_dec = int(l_bin, 2) if l_bin else 0

    # phase as an exact fraction l / 2^M, reduced to lowest terms
    den = 2 ** len(l_bin)
    phi_tilde = l_dec / den  # float value, for display only
    c = gcd(l_dec, den)
    phi = (l_dec // c, den // c)

    # construct convergents for phi
    coefficients = list(contfrac.continued_fraction(phi))
    convergents = list(contfrac.convergents(phi))

    # check convergents for solution
    if verbose:
        print("")
        print("l_measured   :", l_bin, l_dec, "frequency:", count)
        print("phi_phase_bin:", "0."+l_bin)
        print("phi_phase_dec:", phi_tilde)
        print("phi_phase_frc:", phi)
        print("cont frac of phi  :",coefficients)
        print("convergents of phi:", convergents)

    base = int(a)
    modulus = int(N)
    for conv in convergents:
        r = conv[1]
        # modular exponentiation keeps the numbers small; a**(r/2) itself
        # can have hundreds of digits
        root = pow(base, int(r) // 2, modulus)
        r_is_even = r % 2 == 0
        root_is_trivial = (root - 1) % modulus == 0 or (root + 1) % modulus == 0
        r_is_period = pow(base, int(r), modulus) == 1
        if r_is_even and not root_is_trivial and r_is_period:
            sol = True # solution found
            if verbose:
                print("conv:", conv, "r =", r, ": factors")
                # gcd(a^(r/2) -+ 1, N) is unchanged by reducing a^(r/2) mod N
                print("factor1:", gcd(root - 1, modulus))
                print("factor2:", gcd(root + 1, modulus))
        else:
            if verbose:
                print("conv:", conv, "r =", r, ": no factors found")
    return sol

# plot phase histogram with solutions in red
def plot_phase_hist(counts, N, a, u_ver, trnc_lv, width, filename, verbose):
    """Bar-plot the normalised phase histogram, solutions in red.

    ``counts`` maps measured bit strings to counts. Bars are ordered by
    sorted bit string and scaled by the total count. Each bit string is
    passed to ``check_solution``; bars for which it returns true are red
    (and "** found solution" is printed), the rest blue. The plot is drawn
    on a new 10x5 figure titled with ``N``, ``a``, ``u_ver`` (and
    ``trnc_lv`` unless it equals ``"x"``); saving and showing are left to
    the caller.
    """
    # total count for normalization
    total = sum(counts.values())

    # bar heights and colours in sorted-key order
    labels = sorted(counts)
    heights = []
    bar_colors = []
    for label in labels:
        hits = counts[label]
        heights.append(hits / float(total))
        if check_solution(label, hits, N, a, filename, verbose):
            bar_colors.append('red')
            print("** found solution")
        else:
            bar_colors.append('blue')

    # creating the bar plot
    plt.figure(figsize=(10, 5))
    plt.bar(labels, heights, color=bar_colors, width=width)
    if trnc_lv == "x":
        title = 'N = {0} a = {1} u_ver={2}'.format(N, a, u_ver)
    else:
        title = 'N = {0} a = {1} u_ver={2} trnc_lv={3}'.format(N, a, u_ver, trnc_lv)
    plt.title(title)
    plt.xticks(labels, labels, rotation='vertical')
    # leave room for the vertical tick labels
    plt.subplots_adjust(bottom=0.20)

# shor's algorithm                    
def run_shor_hist(N, a, L, M, filename, u_ver, trnc_lv, verbose=False,
                   very_quintet=False, amin_factor=1.0, amax=4000, shots=4096):
    """Run Shor's period-finding circuit and plot the phase histograms.

    Builds a circuit with ``M`` control qubits and ``L`` work qubits
    (Hadamards on the control register, work register with its first qubit
    set to 1, controlled ME operators for powers ``2**q``, inverse QFT,
    measurement of the control register), simulates it on ``aer_simulator``
    and plots two histograms via ``plot_phase_hist``: the full one and a
    pruned one.

    Parameters
    ----------
    N, a, u_ver, trnc_lv :
        Forwarded to the ME operators and the plotting helpers.
    L, M : int
        Number of work and control qubits.
    filename : str
        Prefix for the saved images (``_circuit.jpg``, ``_hist.jpg``,
        ``_hist_prune.jpg`` and the images written by ``plot_U``).
    verbose : bool
        Forwarded to ``plot_phase_hist``.
    very_quintet : bool
        When true, skip plotting the individual ME operators.
    amin_factor : float
        The pruned histogram keeps counts strictly above
        ``amin_factor * shots / 2**M``.
    amax : number
        The pruned histogram keeps counts strictly below this value.
    shots : int
        Number of simulated shots.
    """
    # identify subroutine
    print("** run_shor_hist ...")
    print("** amin factor, amax: ", amin_factor, amax)

    # define circuit dimensions
    n_control = M
    n_work = L

    # plot UME operators
    if not very_quintet:
        print("** plot ME operators ...")
        max_power = n_control
        if u_ver == 0:
            max_power = 4
            print("** truncate plotting at power =", 2*max_power)
        powers = [2**q for q in range(max_power)]
        for power in powers[::-1]:
            plot_U(N, a, L, power, u_ver, trnc_lv, filename)
        plt.show()

    # Quantum circuit
    print("** construct quantum circuit ...")
    m = QuantumRegister(n_control, name='m')
    w = QuantumRegister(n_work, name='w')
    cl  = ClassicalRegister(n_control, name='cl')
    qc = QuantumCircuit(m, w, cl)

    # Apply H to control register
    for q in range(n_control):
        qc.h(q)

    # Work register in state |1>
    qc.x(n_control) # w0=1, w1=0, ..., w(n_work-1)=0

    # Controlled-U operations
    work_qubits = [i+n_control for i in range(n_work)]
    for q in range(n_control):
        qc.append(c_U1(N, a, n_work, 2**q, u_ver, trnc_lv, barrier=False),
                  [q] + work_qubits)

    # Inverse-QFT
    qc.append(qft_dagger(n_control), range(n_control))

    # Measure circuit
    qc.measure(m, cl)

    # plot circuit
    if u_ver != 0:
        print("** plot quantum circuit ...")
        # ask for the matplotlib drawer explicitly and save the figure it
        # returns; with the default (text) drawer no figure is produced and
        # plt.savefig would write whatever figure happens to be current
        circuit_fig = qc.draw(output="mpl", fold=-1, scale=0.75)
        circuit_fig.savefig(filename+"_circuit.jpg")

    # simulate
    print("** run quantum circuit ...")
    aer_sim = Aer.get_backend('aer_simulator')
    t_qc = transpile(qc, aer_sim)
    obj = assemble(t_qc)
    results = aer_sim.run(obj, shots=shots).result()
    counts = results.get_counts() # phase histogram
    print("** counts:", counts)

    # plot histogram of counts with solutions in red
    print("** plot phase histogram")
    width = 0.7 if N > 100 else 0.2
    plot_phase_hist(counts, N, a, u_ver, trnc_lv, width, filename, verbose)
    plt.savefig(filename+"_hist.jpg")
    plt.show()

    # prune counts dictionary: strip small (and overly large) elements
    # threshold is proportional to the uniform expectation shots/2**M
    amin = amin_factor*float(shots)/2**n_control
    print("")
    print("** amin_factor, amin, amax", amin_factor, amin, amax)
    counts_prune = {}
    for l_bin in counts.keys():
        if amin < counts[l_bin] < amax:
            counts_prune[l_bin] = counts[l_bin]
    print("** counts_prune", counts_prune)

    # plot histogram of pruned counts with solutions in red
    print("** plot phase histogram")
    width = 0.2 if N > 100 else 0.05
    plot_phase_hist(counts_prune, N, a, u_ver, trnc_lv, width, filename, verbose)
    plt.savefig(filename+"_hist_prune.jpg")
    plt.show()

# check consistency of the U operator: passes test
def run_con_U(N, a, L, power, x_list, u_ver, f, trnc_lv="x"):
    """Print the classical and quantum ME results for one element of ``x_list``.

    The element at index 3 of ``x_list`` is evaluated classically with
    ``U_cl`` and on the circuit with ``run_quantum_cycle``; both results
    are printed in decimal and binary for visual comparison.

    Parameters
    ----------
    N, a, L, power, u_ver :
        Forwarded to ``U_cl`` / ``run_quantum_cycle``.
    x_list : sequence of int
        Must contain at least four elements.
    f :
        Not used by this function.
    trnc_lv :
        Forwarded to ``run_quantum_cycle``. Previously this name was read
        without being defined in the function; it is now an optional
        parameter. The default ``"x"`` is the value the plotting helpers
        in this file treat as "no truncation level to report" --
        NOTE(review): confirm that it is also the right default for the ME
        operators.
    """
    print("** : check  consistence of U operator")
    print("** :", x_list)

    # choose x from x_list
    n = 3
    x = x_list[n] # decimal
    x_bin = bin(x)[2:].zfill(L) # binary
    print("** x, x_bin              :", x, x_bin)

    # classical calculation
    amod_dec, amod_bin = U_cl(x, a, power, N, L)
    print("** amod_dec, amod_bin    :", amod_dec, amod_bin)

    # quantum calculation
    amod_q_bin = run_quantum_cycle(x_bin, a, L, N, power, u_ver, trnc_lv)
    amod_q_dec = int(amod_q_bin,2)
    print("** amod_q_dec, amod_q_bin:", amod_q_dec, amod_q_bin)

# shor's algorithm by iteration
def run_shor_tries(N, a, L, M, filename, u_ver, trnc_lv, num_it):
    """Measure how many single-shot runs Shor's circuit needs to find factors.

    The period-finding circuit (``M`` control qubits, ``L`` work qubits) is
    built and transpiled once. For each of ``num_it`` iterations it is run
    one shot at a time until ``check_solution`` accepts the measured bit
    string; the number of shots needed is recorded. Progress is printed and
    also written to ``filename + "_tries.txt"``.

    Note that an iteration only ends when a solution is found; there is no
    upper bound on the number of tries.

    Returns
    -------
    (num_try_av, num_try_sd)
        Mean and standard deviation (``np.average`` / ``np.std``) of the
        number of tries over all iterations.
    """
    # the context manager closes the log even if the simulation raises
    with open(filename+"_tries.txt", "w") as f:
        f.write("** N, a, L, M, u_ver, trnc_lv: {0} {1} {2} {3} {4} {5}".format(N, a, L, M, u_ver, trnc_lv))
        f.write("\n")
        print("** trnc_lv:", trnc_lv)
        f.write("** trnc_lv: {0}\n".format(trnc_lv))

        # create quantum circuit
        n_control = M
        n_work = L

        m = QuantumRegister(n_control, name='m')
        w = QuantumRegister(n_work, name='w')
        cl  = ClassicalRegister(n_control, name='cl')
        qc = QuantumCircuit(m, w, cl)

        # Apply H to control register
        for q in range(n_control):
            qc.h(q)

        # Work register in state |1>
        qc.x(n_control) # w0=1, w1=0, ..., w(n_work-1)=0

        # Controlled-U operations
        work_qubits = [i+n_control for i in range(n_work)]
        for q in range(n_control):
            qc.append(c_U1(N, a, n_work, 2**q, u_ver, trnc_lv, barrier=False),
                      [q] + work_qubits)

        # Inverse-QFT
        qc.append(qft_dagger(n_control), range(n_control))

        # Measure circuit
        qc.measure(m, cl)

        # transpile once; the same job object is re-run for every shot
        aer_sim = Aer.get_backend('aer_simulator')
        t_qc = transpile(qc, aer_sim)
        obj = assemble(t_qc)

        num_tries = []
        for it in range(num_it):
            print()
            print("** it:", it)
            f.write(" \n")
            f.write("** it: {0}\n".format(it))
            sol = False
            num_try = 0
            while not sol:
                num_try += 1

                # simulate a single shot
                results = aer_sim.run(obj, shots=1).result()
                counts = results.get_counts()

                # check solution
                verbose = False
                for l_bin in counts.keys():
                    sol = check_solution(l_bin, counts[l_bin], N, a, filename, verbose)
                    print("l_bin:", l_bin, sol)
                    f.write("l_bin: {0} {1}\n".format(l_bin, sol))
                    if sol:
                        print("number of tries:", num_try)
                        f.write("number of tries: {0}\n".format(num_try))
            num_tries.append(num_try)

        print()
        print("u_ver:", u_ver)
        print("trnc_lv:", trnc_lv)
        print("num_it:", num_it)
        print("num_tries:", num_tries)
        f.write("\n")
        f.write("u_ver: {0}\n".format(u_ver))
        f.write("trnc_lv: {0}\n".format(trnc_lv))
        f.write("num_it: {}\n".format(num_it))
        f.write("num_tries: {0}\n".format(num_tries))
        num_try_av = np.average(num_tries)
        num_try_sd = np.std(num_tries)
        print("num_try_av:", num_try_av)
        print("num_try_sd:", num_try_sd)
        print()
        f.write("num_try_av: {0}\n".format(num_try_av))
        f.write("num_try_sd: {0}\n".format(num_try_sd))
        f.write("\n")
    return num_try_av, num_try_sd


# qubit recycling
def run_shor_recycle(N, a, L, M, filename, u_ver, trnc_lv, shots, verbose=False):
    """Run Shor's algorithm with a single recycled control qubit.

    One control qubit is reset and reused ``M`` times. Step ``cnt`` applies
    the controlled ME operator for power ``2**(M-cnt-1)``, corrects the
    control phase using the bits measured in earlier steps (classically
    conditioned phase gates), applies a Hadamard and measures into its own
    one-bit classical register ``cl<cnt>``. The circuit is drawn and saved
    to ``filename + "_circ.jpg"``, simulated for ``shots`` shots on
    ``aer_simulator``, and the phase histogram is plotted via
    ``plot_phase_hist`` and saved to ``filename + "_hist.jpg"``.

    Parameters
    ----------
    N, a, u_ver, trnc_lv :
        Forwarded to the ME operator and the plotting helper.
    L : int
        Number of work qubits.
    M : int
        Number of phase bits (control-qubit reuses).
    filename : str
        Prefix for the saved images.
    shots : int
        Number of simulated shots.
    verbose : bool
        Forwarded to ``plot_phase_hist``.
    """
    print("** run_shor_recycle ...")
    print("** trnc_lv:", trnc_lv)
    n_control = 1
    n_work = L

    # Quantum circuit
    m = QuantumRegister(n_control, name='m')
    w = QuantumRegister(n_work, name='w')

    # one single-bit classical register per measured phase bit
    cl_reg = [ClassicalRegister(1, name='cl' + str(c)) for c in range(M)]
    qc = QuantumCircuit(m, w, *cl_reg)

    # Work register in state |1> (qubit 1 is the first work qubit)
    qc.x(1)

    work_qubits = [i+n_control for i in range(n_work)]
    for cnt in range(M):

        # initialize control register to 0
        qc.initialize([1, 0], m)

        # Apply H to control register
        qc.h(0)

        # apply ME operator, highest power first
        power = 2**(M-cnt-1)
        qc.append(c_U1(N, a, n_work, power, u_ver, trnc_lv, barrier=False),
                  [0] + work_qubits)

        # feed-forward phase correction: remove the contribution of the
        # bits measured so far. It has to act before the final Hadamard --
        # a phase gate placed directly before the measurement is diagonal
        # in the measured basis and cannot change the outcome -- and it
        # rotates by the negative angle.
        for myn in range(cnt):
            qc.p(-np.pi/2**(cnt-myn), 0).c_if(cl_reg[myn], 1)
        qc.h(0)

        # measure control register
        qc.measure(m, cl_reg[cnt])
        qc.barrier()

    # draw circuit with the matplotlib drawer and save the figure it returns;
    # the default (text) drawer creates no figure for plt.savefig to write
    circuit_fig = qc.draw(output="mpl", fold=50)
    circuit_fig.savefig(filename+'_circ.jpg')
    plt.show()

    # run  circuit
    aer_sim = Aer.get_backend('aer_simulator')
    t_qc = transpile(qc, aer_sim)
    obj = assemble(t_qc)
    results = aer_sim.run(obj, shots=shots).result()
    counts_sp = results.get_counts()

    # remove spaces from keys (one group per classical register)
    counts = {}
    for l_bin in counts_sp.keys():
        l_bin_nosp = l_bin.replace(" ", "")
        counts[l_bin_nosp] = counts_sp[l_bin]

    # plot phase histogram
    width = 0.2 if N > 100 else 0.1
    plot_phase_hist(counts, N, a, u_ver, trnc_lv, width, filename, verbose)
    plt.savefig(filename+'_hist.jpg')
    plt.show()
