/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.nnet.learning;

import java.util.ArrayList;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.transfer.Gaussian;
import org.neuroph.nnet.learning.LMS;
import org.neuroph.nnet.learning.kmeans.Cluster;
import org.neuroph.nnet.learning.kmeans.KMeansClustering;
import org.neuroph.nnet.learning.kmeans.KVector;
import org.neuroph.nnet.learning.knn.KNearestNeighbour;

/**
 * Learning rule for radial basis function (RBF) networks.
 *
 * <p>Before LMS training starts, the layer at index 1 (the RBF layer) is initialised from the
 * training set:
 * <ul>
 *   <li>the training vectors are grouped with k-means into as many clusters as that layer has
 *       neurons;</li>
 *   <li>each neuron's input connection weights are set to the coordinates of its cluster
 *       centroid;</li>
 *   <li>each neuron's {@link Gaussian} sigma is set to the root-mean-square distance to the
 *       {@link #k} nearest neighbouring centroids.</li>
 * </ul>
 * The remaining training is inherited from {@link LMS}.
 */
public class RBFLearning extends LMS {
    /** Number of nearest neighbouring centroids used when computing each neuron's sigma. */
    int k = 2;

    /**
     * Initialises the RBF layer's centres (input weights) and widths (Gaussian sigma) from a
     * k-means clustering of the training set, after the superclass start-up logic has run.
     *
     * <p>Assumes every neuron in layer 1 uses a {@link Gaussian} transfer function (the cast
     * below fails otherwise).
     */
    @Override
    protected void onStart() {
        super.onStart();
        KMeansClustering kmeans = new KMeansClustering(this.getTrainingSet());
        kmeans.setNumberOfClusters(this.neuralNetwork.getLayerAt(1).getNeuronsCount());
        kmeans.doClustering();
        Cluster[] clusters = kmeans.getClusters();
        Layer rbfLayer = this.neuralNetwork.getLayerAt(1);

        // Centre neuron i on centroid i: copy the centroid's coordinates, one per input
        // connection, into the neuron's input weights.
        int i = 0;
        for (Neuron neuron : rbfLayer.getNeurons()) {
            double[] weightValues = clusters[i].getCentroid().getValues();
            int c = 0;
            for (Connection conn : neuron.getInputConnections()) {
                conn.getWeight().setValue(weightValues[c]);
                ++c;
            }
            ++i;
        }

        ArrayList<KVector> centroids = new ArrayList<KVector>();
        for (Cluster cluster : clusters) {
            centroids.add(cluster.getCentroid());
        }
        // NOTE(review): the data set contains the query centroid itself; whether
        // getKNearestNeighbours excludes the query point is not visible here — confirm.
        KNearestNeighbour knn = new KNearestNeighbour();
        knn.setDataSet(centroids);

        // Width of neuron n comes from the neighbours of centroid n. The index must advance
        // together with the centroid so that every neuron (not just neuron 0) gets its sigma.
        for (int n = 0; n < centroids.size(); n++) {
            KVector centroid = centroids.get(n);
            KVector[] nearestNeighbours = knn.getKNearestNeighbours(centroid, this.k);
            double sigma = this.calculateSigma(centroid, nearestNeighbours);
            Neuron neuron = rbfLayer.getNeuronAt(n);
            ((Gaussian) neuron.getTransferFunction()).setSigma(sigma);
        }
    }

    /**
     * Computes a Gaussian width as the root-mean-square distance from {@code centroid} to
     * the given neighbours: {@code sqrt((1/p) * sum(distance_j^2))}, where {@code p} is the
     * number of neighbours.
     *
     * @param centroid the centre of the RBF neuron
     * @param nearestNeighbours the neighbouring centroids; if the array is empty the result
     *     is {@code NaN} (0 * Infinity)
     * @return the RMS distance to the neighbours
     */
    private double calculateSigma(KVector centroid, KVector[] nearestNeighbours) {
        double sigma = 0.0;
        for (KVector nn : nearestNeighbours) {
            sigma += Math.pow(centroid.distanceFrom(nn), 2.0);
        }
        sigma = Math.sqrt(1.0 / (double) nearestNeighbours.length * sigma);
        return sigma;
    }
}

