/*
 * Decompiled with CFR 0.152.
 */
package net.andreinc.markovneat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import net.andreinc.markovneat.MProb;
import net.andreinc.markovneat.MState;

/**
 * A Markov chain whose states are windows of up to {@code noStates} consecutive elements of
 * type {@code T}.
 *
 * <p>Training records, for every state, the elements observed right after it (each with a
 * weight) in an {@link MProb}; generation walks the chain by repeatedly sampling the next
 * element from the current state's {@link MProb}.
 *
 * @param <T> the type of the elements the chain is trained on and generates
 */
public class MChain<T> {
    /** Source of randomness handed to every {@link MProb} and used by {@link #randomState()}. */
    protected Random random;
    /** Transition table: state -> weighted elements observed to follow that state. */
    protected Map<MState<T>, MProb<T>> chain = new ConcurrentHashMap<MState<T>, MProb<T>>();
    /**
     * Snapshot of the keys of {@link #chain} used for random state selection; rebuilt by
     * {@link #randomState()} whenever its size no longer matches the chain's key count.
     */
    protected List<MState<T>> states = new ArrayList<MState<T>>();
    /** Number of elements that make up one state. */
    private final int noStates;

    /** Creates a chain whose states consist of a single element. */
    public MChain() {
        this(1);
    }

    /**
     * Creates a chain with the given state size, using {@link ThreadLocalRandom} for sampling.
     *
     * <p>NOTE(review): the {@code ThreadLocalRandom.current()} instance of the constructing
     * thread is stored in a field; its documentation advises against sharing it across threads,
     * so pass an explicit {@link Random} if the chain is used from several threads.
     *
     * @param noStates number of elements per state; must be {@code >= 1}
     * @throws IllegalArgumentException if {@code noStates < 1}
     */
    public MChain(int noStates) {
        this(noStates, ThreadLocalRandom.current());
    }

    /**
     * Creates a chain with the given state size and source of randomness.
     *
     * @param noStates number of elements per state; must be {@code >= 1}
     * @param random source of randomness used for sampling transitions and random states
     * @throws IllegalArgumentException if {@code noStates < 1}
     */
    public MChain(int noStates, Random random) {
        if (noStates < 1) {
            throw new IllegalArgumentException(
                    "The number of states used to create the Markov chain needs to be >= 1");
        }
        this.noStates = noStates;
        this.random = random;
    }

    /**
     * Records one observation of {@code element} following {@code state}, with weight 1.0.
     *
     * @param state the state the element followed
     * @param element the observed next element
     */
    public void add(MState<T> state, T element) {
        this.transitionsOf(state).add(1.0, element);
    }

    /**
     * Records an observation of {@code element} following {@code state} with the given weight.
     *
     * @param state the state the element followed
     * @param element the observed next element
     * @param weight the weight passed to {@link MProb} for this observation
     */
    public void add(MState<T> state, T element, double weight) {
        this.transitionsOf(state).add(weight, element);
    }

    /** Returns the transition distribution for {@code state}, creating an empty one if absent. */
    private MProb<T> transitionsOf(MState<T> state) {
        // computeIfAbsent allocates an MProb only for a state that has not been seen yet.
        return this.chain.computeIfAbsent(state, key -> new MProb<T>(this.random));
    }

    /**
     * Trains the chain on a sequence of elements.
     *
     * <p>Unlike {@link #train(Object[])}, no length check is performed: a sequence with at most
     * {@code noStates} elements records no transitions.
     *
     * @param elements the training sequence
     */
    public void train(Iterable<T> elements) {
        this.train(elements.iterator());
    }

    /**
     * Trains the chain on the given elements.
     *
     * <p>NOTE(review): exactly {@code noStates} elements passes the check but records no
     * transition, because every element is consumed building the initial state.
     *
     * @param elements the training sequence
     * @throws IllegalArgumentException if fewer than {@code noStates} elements are given
     */
    public void train(T... elements) {
        if (elements.length < this.noStates) {
            throw new IllegalArgumentException(
                    "Cannot train a chain based on a number of elements smaller than noStates.");
        }
        this.train(Arrays.asList(elements).iterator());
    }

    /**
     * Trains the chain from an iterator: the first {@code noStates} elements form the initial
     * state, and every following element is recorded as a transition out of the current state,
     * which then advances via {@link MState#nextState}.
     *
     * @param iterator source of training elements
     */
    protected void train(Iterator<T> iterator) {
        MState<T> state = new MState<T>();
        while (iterator.hasNext()) {
            T next = iterator.next();
            if (state.data().size() < this.noStates) {
                // Still filling the initial window; no transition can be recorded yet.
                state.data().add(next);
            } else {
                this.add(state, next);
                state = state.nextState(next);
            }
        }
    }

    /**
     * Generates elements starting from a uniformly chosen known state.
     *
     * @param numElements number of generation steps; must be {@code > 0}
     * @return the generated elements (see {@link #generate(MState, int)} for the exact length)
     * @throws IllegalArgumentException if the chain is empty or {@code numElements <= 0}
     */
    public List<T> generate(int numElements) {
        return this.generate(this.randomState(), numElements);
    }

    /**
     * Picks a state uniformly at random among the states recorded in the chain.
     *
     * @return a state that is a key of the chain
     * @throws IllegalArgumentException if the chain has not been trained
     */
    public MState<T> randomState() {
        if (this.chain.isEmpty()) {
            throw new IllegalArgumentException("Markov chain is empty. Please train the chain first.");
        }
        // States are only ever added, so a size mismatch means the cached snapshot is stale.
        if (this.states.size() != this.chain.size()) {
            this.states = new ArrayList<MState<T>>(this.chain.keySet());
        }
        return this.states.get(this.random.nextInt(this.states.size()));
    }

    /**
     * Generates elements by walking the chain from {@code initialState}.
     *
     * <p>The result starts with the initial state's elements. Each of the {@code numElements}
     * steps then either appends one element sampled from the current state's transitions, or,
     * when the current state has no recorded transitions, jumps to {@link #randomState()} and
     * appends all of that state's elements. The returned list can therefore be longer than
     * {@code numElements}.
     *
     * @param initialState starting state; must be a key of the chain
     * @param numElements number of generation steps; must be {@code > 0}
     * @return a new mutable list of generated elements
     * @throws IllegalArgumentException if the chain is empty, {@code numElements <= 0}, or
     *     {@code initialState} is not present in the chain
     */
    public List<T> generate(MState<T> initialState, int numElements) {
        if (this.chain.isEmpty()) {
            throw new IllegalArgumentException("Markov chain is empty. Please train the chain first.");
        }
        if (numElements <= 0) {
            throw new IllegalArgumentException("The initial number of elements cannot be negative or zero. (>0)");
        }
        if (!this.chain.containsKey(initialState)) {
            throw new IllegalArgumentException("The initial state cannot be found in the Markov Chain. Please use an existing state.");
        }
        List<T> result = new ArrayList<T>();
        MState<T> state = initialState.shallowCopy();
        result.addAll(state.data());
        for (int step = 0; step < numElements; step++) {
            // ConcurrentHashMap never stores null values, so null means "no transitions".
            MProb<T> transitions = this.chain.get(state);
            if (transitions == null) {
                state = this.randomState();
                result.addAll(state.data());
                continue;
            }
            T element = transitions.next();
            result.add(element);
            state = state.nextState(element);
        }
        return result;
    }
}

