"""
Learning of an input pattern with STDP for testing with FACETS hardware.

Olivier Bichler, CEA LIST <olivier.bichler@cea.fr>
April 2010
"""

import numpy, pylab, math

from pyNN.utility import get_script_args

sim_name = get_script_args(1)[0]
# sim_name comes straight from the command line and is interpolated into an
# exec'd import statement below — reject anything that is not a plausible
# module path to avoid arbitrary-code injection.
if not sim_name.replace('_', '').replace('.', '').isalnum():
    raise ValueError("Invalid simulator name: %r" % sim_name)
# Wildcard import of the chosen pyNN backend (e.g. pyNN.nest, pyNN.neuron)
# so that setup(), Population, run(), end(), etc. resolve to that backend.
exec("from pyNN.%s import *" % sim_name)

# Parameters
#################

Tstep = 15.0        # Duration of one learning step (ms)
Tinput = 5.0        # Duration of the input stimulus within a learning step (ms)    [Tstep > Tinput + tau_refrac]
Tstart = 0.1        # Start time of the input stimulus, relative to the beginning of a learning step (ms)

Nstep = 60          # Number of learning steps
Nview = 20          # Number of steps displayed at the end of the simulation
InputSize = 512     # Input pattern length (= number of input synapses)

# Weights are stored on 4 bits in hardware
Wmin = 0.0
Wmax = 0.005
Winit = 0.0003125   # = Wmax / 16, i.e. one 4-bit quantization step

# STDP parameters
STDP_TauPlus = 5.0      # Potentiation time constant (ms) — can be decreased
STDP_TauMinus = 5.0     # Depression time constant (ms)
STDP_APlus = 0.03       # Potentiation amplitude — can be increased
STDP_AMinus = 0.03      # Depression amplitude

# Commented-out parameters are not available with IF_facets_hardware1
NeuronParams = {
    'v_rest'     : -70.0,   # Resting membrane potential in mV.
#    'cm'         : 1.0,     # Capacity of the membrane in nF
#    'tau_m'      : 20.0,    # Membrane time constant in ms.
    'tau_refrac' : 10.0,     # Duration of refractory period in ms.
                                # <--- not available with IF_facets_hardware1, but needs to be changed
    'tau_syn_E'  : 2.0,     # Decay time of the excitatory synaptic conductance in ms.            <--- 2.0 instead of 5.0
    'tau_syn_I'  : 5.0,     # Decay time of the inhibitory synaptic conductance in ms.
#    'e_rev_E'    : 0.0,     # Reversal potential for excitatory input in mV
    'e_rev_I'    : -75.0,   # Reversal potential for inhibitory input in mV
    'v_thresh'   : -55.0,   # Spike threshold in mV.
    'v_reset'    : -70.0,   # Reset potential after a spike in mV.
#    'i_offset'   : 0.0,     # Offset current in nA
#    'v_init'     : -70.0,   # Membrane potential in mV at t = 0
}

# Build networks
#################

setup(timestep=0.001, min_delay=0.1, max_delay=1.0, debug=True, quit_on_end=False)

# Input layer: one spike source per synapse. Each source fires once per
# learning step, at an offset proportional to its index, so that the whole
# population sweeps out the input pattern over Tinput milliseconds.
p1 = Population(InputSize, SpikeSourceArray)

for cell_index, cell in enumerate(p1.all_cells):
    first_spike = Tstart + float(cell_index) / InputSize * Tinput
    cell.spike_times = numpy.arange(first_spike, Nstep * Tstep, Tstep)

# Output layer: a single integrate-and-fire neuron.
# (Use IF_facets_hardware1 instead of IF_cond_exp to target the hardware.)
p2 = Population(1, IF_cond_exp, NeuronParams)

# Plastic all-to-all connectivity governed by pair-based additive STDP.
stdp_model = STDPMechanism(
    timing_dependence=SpikePairRule(tau_plus=STDP_TauPlus,
                                    tau_minus=STDP_TauMinus),
    weight_dependence=AdditiveWeightDependence(w_min=Wmin, w_max=Wmax,
                                               A_plus=STDP_APlus,
                                               A_minus=STDP_AMinus))

prj = Projection(p1, p2,
                 method=AllToAllConnector(weights=Winit, delays=0.1),
                 synapse_dynamics=SynapseDynamics(slow=stdp_model))

p2.record()

# Run the simulation one learning step at a time, snapshotting the full
# weight vector after every step so the learning can be replayed below.
# (The original also accumulated run()'s return values in a list `t`,
# but that list was never read and was later shadowed — dropped here.)
w = []   # w[k] = flat list of the InputSize synaptic weights after step k

for step in range(Nstep):
    run(Tstep)
    w.append(prj.getWeights())

# Display results
#################

# For each output spike, compute its time relative to the start of the
# learning step in which it occurred, so the neuron's firing latency can
# be plotted against the step index.
x = range(Nstep)
y = []
for cell_id, spike_time in p2.getSpikes():
    y.append(spike_time - math.floor(spike_time / Tstep) * Tstep)

# Exactly one output spike per learning step is expected; otherwise the
# per-step latency plot below is meaningless.
assert len(x) == len(y), "Neuron did not fire exactly once per learning step!"

pylab.figure(1)
pylab.plot(x, y)
pylab.xlabel("# step")
pylab.ylabel("Fire time (ms)")
pylab.grid()
pylab.title("Evolution of the neuron's activation time during STDP learning")

# Weight-distribution snapshots: one bar chart for each of Nview evenly
# spaced learning steps, laid out on a near-square grid of subplots.
x = range(InputSize)
rows = int(math.ceil(math.sqrt(Nview)))
cols = int(math.ceil(float(Nview) / rows))

pylab.figure(2)
pylab.subplots_adjust(left=0.1, bottom=0.05, right=0.9, top=0.9, wspace=0.3, hspace=0.7)

for i in range(Nview):
    # Integer floor division: under Python 3, `i*Nstep/Nview` would yield
    # a float and w[n] would raise TypeError.
    n = i * Nstep // Nview
    pylab.subplot(rows, cols, i + 1)
    pylab.bar(x, w[n])
    pylab.grid()
    pylab.ylim([Wmin, Wmax])
    pylab.xlim([0, InputSize])
    pylab.title("step #%d" % (n + 1))

pylab.show()

end()

