#!/usr/bin/env python
#-----------------------------------------------------------------------------
# Model Factory Interface:
"""NOTES:
- forward model "forward_poly" calculates a function of x (w/ fixed a,b,c)
- ForwardPolyFactory is a "function generator", allowing a,b,c to be set
"""
def ForwardPolyFactory(params):
    """Build a quadratic forward model with fixed coefficients.

    params -- sequence that unpacks into exactly three coefficients (a, b, c)

    Returns a function "forward_poly(x)" that evaluates a*x*x + b*x + c.
    """
    a, b, c = params

    def forward_poly(x):
        """Evaluate the quadratic at x (x is expected to be a 1D numpy array)."""
        quadratic = a*x*x
        linear = b*x
        return array(quadratic + linear + c)

    return forward_poly











#-----------------------------------------------------------------------------
# Forward Model Invocation:
"""NOTES:
- fwd is an instance of "forward_poly", built with chosen a,b,c
- "data" converts a function of x into a function of a,b,c (w/ fixed x) [i.e. a functor]
- same methodology is used in COST FUNCTION to produce "goodness of fit"
"""
def data(params):
    """Evaluate the polynomial model for "params" on a fixed grid.

    params -- coefficients (a, b, c) handed to ForwardPolyFactory

    Returns the model values at the 101 points -50., -49., ..., 50.
    """
    model = ForwardPolyFactory(params)
    grid = array(range(101)) - 50.  # fixed evaluation points
    return model(grid)











#-----------------------------------------------------------------------------
# Build "Measured" Data: (optional... use real measured data)
"""NOTES:
- target is "target solution" for a,b,c
- data is used to generate "measured data" (parameters a,b,c = target)
"""
# "target solution": the coefficients (a, b, c) used to generate the data
target = [1.0, 2.0, 1.0]
# synthetic "measured" data: the model evaluated at the target coefficients
datapts = data(target)















#-----------------------------------------------------------------------------
# Cost Function Generation: (optional... write your cost function explicitly)
"""NOTES:
- F is an instance of Cost Function (goodness of fit) generator 
- myCost is an instance of a Cost Function
- (default metric) calculates the least-squares difference between fwd(x) and datapts
"""
# evaluation points: 101 values running from -50. to 50.
x = array([range(101)]) - 50.
x = x[0]
# register the polynomial factory with the cost-function generator
# (3 is presumably the number of model parameters -- confirm against CostFactory)
F = CostFactory()
F.addModel(ForwardPolyFactory, 3, 'poly')
# cost function comparing the model at x against the "measured" datapts
myCost = F.getCostFunction(evalpts=x, observations=datapts)












#-----------------------------------------------------------------------------
# Call to Solver:
"""NOTES:
- solution is set of solved parameters a,b,c
- stepmon holds a log of optimization steps
"""
# Run the solver on the cost function built above and unpack the pair it
# returns: the solved parameters and the step monitor.
# NOTE(review): de_solve is defined further down in this file, so executing
# the file strictly top-to-bottom would reach this call before the definition
# exists -- confirm how this file is meant to be run.
solution, stepmon = de_solve(myCost)






#-----------------------------------------------------------------------------
ND = 3                      # number of parameters (a, b, c)
NP = 80                     # size of the trial population
MAX_GENERATIONS = ND * NP   # maximum number of optimization iterations
#-----------------------------------------------------------------------------
# Standard "Solver" Configuration:
"""NOTES:
- ND is number of parameters (a,b,c)
- NP is size of trial population
- MAX_GENERATIONS is maximum optimization iterations
#-----------------------------------------------------------------------------
- VerboseMonitor logs/prints "goodness of fit" and "best solution" at each step
- minrange/maxrange provide box constraints (for parameters a,b,c)
- SetRandomInitialPoints chooses an initial solution within box constraints
- SetStrictRanges only allows trial solutions within box constraints
- the 'termination' condition ends the run when there is "no change" over 300 generations
- enable_signal_handler allows "interrupt" signal to be caught
- sigint_callback registers a user-provided function to the signal_handler
"""
def de_solve(CF):
    """Minimize the cost function CF with a DifferentialEvolutionSolver.

    CF -- cost function to minimize

    Returns a tuple (solution, stepmon): the solver's solution and the
    VerboseMonitor that was attached as its generation monitor.
    """
    solver = DifferentialEvolutionSolver(ND, NP)
    solver.enable_signal_handler()  # lets an "interrupt" signal be caught
    stepmon = VerboseMonitor(10, 50)

    # box constraints for the parameters (a, b, c)
    lower = [-100., -100., -100.]
    upper = [100., 100., 100.]
    solver.SetRandomInitialPoints(min=lower, max=upper)
    solver.SetStrictRanges(min=lower, max=upper)
    solver.SetEvaluationLimits(maxiter=MAX_GENERATIONS)
    solver.SetGenerationMonitor(stepmon)

    # plot_sol is registered as the callback for a caught interrupt signal
    stop = ChangeOverGeneration(generations=300)
    solver.Solve(CF, stop, CrossProbability=0.5, ScalingFactor=0.5,
                 sigint_callback=plot_sol)

    return solver.Solution(), stepmon
#-----------------------------------------------------------------------------
# BONUS... the Callback Function:
"""NOTES:
- called on "catch" of signal-interrupt
- _MUST_ be a function of "params"
- _only_one_ configuration parameter is (currently) allowed
"""
def plot_sol(params, linestyle='b-'):
    """Plot the polynomial described by "params" over the fixed grid.

    params    -- coefficients (a, b, c) passed on to data()
    linestyle -- plot format, inserted into the format string given to pylab

    Draws with pylab and returns None. The axis limits come from "plotview",
    which is not defined in this block (presumably set elsewhere -- confirm).
    """
    grid = array(range(101)) - 50.  # same 101 points that data() evaluates on
    curve = data(params)
    pylab.plot(grid, curve, '%s' % linestyle, linewidth=2.0)
    pylab.axis(plotview)










# DONE
