# Example of creating an individual with multiple 'parts'

# Derek M Tishler, Oct 2020

'''
Evolve a complex structure/individual, such as the encoded hyperparameters of an MLP classifier:
    individual_example = [encoded_solver__int, learning_rate__float, hidden_units__list]
Example of a 10-layer network using the adam solver with learning rate ~0.007:
    [0, 0.006592, [13, 89, 21, 100, 96, 100, 16, 67, 45, 88]]
'''
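
# Decoded, the example above reads: solver='adam', learning_rate_init=0.006592,
# hidden_layer_sizes=[13, 89, 21, 100, 96, 100, 16, 67, 45, 88].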

import random

import numpy

from deap import algorithms
from deap import base
from deap import creator
from deap import tools

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from sklearn.utils._testing import ignore_warnings  # moved from sklearn.utils.testing in scikit-learn >= 0.24
from sklearn.exceptions import ConvergenceWarning

import multiprocessing

creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", list, fitness=creator.FitnessMax)

toolbox = base.Toolbox()

random.seed(64)
numpy.random.seed(64)

### Data for training ###
X, y = make_classification(n_samples=100, n_features=20, n_classes=3, 
                           n_informative=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
                                                    random_state=1)
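# stratify=y keeps the three-class proportions consistent across the split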
#########################

### Set up the hyperparameter search space for the MLP classifier ###
# https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html

# part 1 of individual
solvers = {
    0:"adam",
    1:"sgd",
    2:"lbfgs",
}

# part 2 of individual
init_learning_rate_range = (1e-8, 1e-2)

# part 3 of individual 
min_layers = 2
max_layers = 10
min_hidden_units = 3
max_hidden_units = 100
################################################

def create_random_mlp_network(cls):
    new_individual = []

    # Part 1 is the encoded solver, such as 0='adam', 1='sgd', 2='lbfgs'
    new_individual.append(random.randrange(len(solvers)))

    # Part 2 is the learning rate
    new_individual.append(random.uniform(*init_learning_rate_range))

    # Part 3 is the hidden_layer_sizes
    new_network = [random.randrange(min_hidden_units, max_hidden_units+1)
                   for _ in range(random.randrange(min_layers, max_layers+1))]
    new_individual.append(new_network)

    return cls(new_individual)
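
# A freshly sampled individual might look like (values illustrative):
#     [1, 0.00342, [57, 12, 88]]  ->  sgd solver, lr ~0.003, three hidden layers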

@ignore_warnings(category=ConvergenceWarning)
def evalMLP(individual):

    clf = MLPClassifier(random_state=1, 
                        max_iter=100,

                        solver=solvers[individual[0]],
                        learning_rate_init=individual[1],
                        hidden_layer_sizes=individual[2],

                       ).fit(X_train, y_train)

    # DEAP fitnesses are tuples, hence the trailing comma
    return clf.score(X_test, y_test),

def constrain_layers(hidden_layer_sizes):
    # Ensure the layer list never exceeds max_layers. With cxOnePoint the two
    # parents effectively swap lengths, so this cap is defensive.
    return hidden_layer_sizes[:max_layers]

def mate(ind1, ind2):
    # Swap the encoded solvers (a crossover on the first index only)
    if random.random() < 0.333:
        ind1[0], ind2[0] = ind2[0], ind1[0]

    # Swap the learning rates
    if random.random() < 0.333:
        ind1[1], ind2[1] = ind2[1], ind1[1]

    # More interestingly, we can use something like cxOnePoint to cross the layer lists
    if random.random() < 0.333:
        ind1[2], ind2[2] = tools.cxOnePoint(ind1[2], ind2[2])
        ind1[2] = constrain_layers(ind1[2])
        ind2[2] = constrain_layers(ind2[2])

    return ind1, ind2

def mutate(ind, indpb=0.05):
    # Random int in range, such as 0='adam', 1='sgd', 2='lbfgs'
    if random.random() < indpb:
        ind[0] = random.randrange(len(solvers))

    # Random float in the init learning rate range (1e-8, 1e-2)
    if random.random() < indpb:
        ind[1] = random.uniform(*init_learning_rate_range)

    # The layer list supports richer mutations: resize a layer, add one, or remove one
    if random.random() < indpb:
        # Set a random layer to a new random number of hidden units
        ind[2][random.randrange(len(ind[2]))] = random.randrange(min_hidden_units, max_hidden_units+1)
    if random.random() < indpb:
        # Insert a random layer at a random position (including the end)
        if len(ind[2]) < max_layers:
            ind[2].insert(random.randrange(len(ind[2]) + 1),
                          random.randrange(min_hidden_units, max_hidden_units+1))
    if random.random() < indpb:
        # Remove a random layer
        if len(ind[2]) > min_layers:
            del ind[2][random.randrange(len(ind[2]))]

    # DEAP expects mutate to return a tuple of individuals
    return ind,
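
# Note: algorithms.varOr (used by eaMuPlusLambda) invalidates offspring fitness
# after mate/mutate runs, so modified individuals are re-evaluated automatically.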

# Structure initializers
toolbox.register("individual", create_random_mlp_network, creator.Individual)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

toolbox.register("evaluate", evalMLP)
toolbox.register("mate", mate)
toolbox.register("mutate", mutate, indpb=0.05)
toolbox.register("select", tools.selTournament, tournsize=3)

pool = multiprocessing.Pool(4)
toolbox.register("map", pool.map)

def main():

    NGEN = 10         # number of generations
    MU = 100          # parents kept each generation
    LAMBDA = 2*MU     # offspring produced each generation
    CXPB = 0.7        # probability an offspring comes from crossover
    MUTPB = 0.2       # probability an offspring comes from mutation (CXPB + MUTPB must be <= 1)
    
    pop = toolbox.population(n=MU)
    hof = tools.HallOfFame(1)

    # Track fitness plus structural properties of the individuals
    stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
    stats_layers = tools.Statistics(lambda ind: len(ind[2]))
    stats_units = tools.Statistics(lambda ind: numpy.mean(ind[2]))
    stats = tools.MultiStatistics(fitness=stats_fit, layers=stats_layers, hidden_units=stats_units)
    stats.register("avg", numpy.mean)
    stats.register("std", numpy.std)
    stats.register("min", numpy.min)
    stats.register("max", numpy.max)
    
    #pop, log = algorithms.eaSimple(pop, toolbox, cxpb=0.5, mutpb=0.2, ngen=NGEN,
    #                               stats=stats, halloffame=hof, verbose=True)

    pop, log = algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN,
                                         stats=stats, halloffame=hof, verbose=True)
    
    top_ind = hof[0]
    print("Best Score: %0.3f" % top_ind.fitness.values[0])
    print("Solver: %s" % solvers[top_ind[0]])
    print("Learning Rate Init: %g" % top_ind[1])
    print("Layers(%d): %s" % (len(top_ind[2]), str(top_ind[2])))
    print (top_ind)

    return pop, log, hof

if __name__ == "__main__":
    main()