/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexNull;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.index.dim.NDIndexTake;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/**
 * A builder for multi-dimensional indices into an {@link NDArray}.
 *
 * <p>An index is an ordered list of {@link NDIndexElement}s plus an optional ellipsis position.
 * Elements are appended either programmatically (for example {@link #addAllDim()} or {@link
 * #addSliceDim(long, long)}) or by parsing a comma-separated string with {@link
 * #addIndices(String, Object...)}.
 */
public class NDIndex {

    /*
     * Grammar for one comma-separated item (already trimmed). Capture groups:
     *   1: "*"                           -> whole dimension
     *   2: a slice "min:max[:step]"      (each numeric part optional)
     *   3: slice min   4: slice max   5: ":step" suffix   6: slice step
     *   7: a lone integer or "{}"        -> fixed position / positional argument
     * The final "null" alternative has no group and is detected by string equality.
     * Any numeric position may be "{}", which is filled from the varargs arguments in order.
     */
    private static final Pattern ITEM_PATTERN = Pattern.compile("(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})|null");

    /** Number of dimensions this index addresses, as accumulated by the add methods. */
    private int rank = 0;

    /** The index elements, in the order they were added. */
    private final List<NDIndexElement> indices = new ArrayList<NDIndexElement>();

    /** Position in {@link #indices} where the ellipsis ("...") sits, or -1 if there is none. */
    private int ellipsisIndex = -1;

    /** Creates an empty index. */
    public NDIndex() {
    }

    /**
     * Creates an index by parsing {@code indices}.
     *
     * @param indices comma-separated index items; see {@link #addIndices(String, Object...)}
     * @param args values substituted, in order, for each {@code {}} placeholder
     * @throws IllegalArgumentException if the string is malformed or the arguments do not match
     */
    public NDIndex(String indices, Object... args) {
        addIndices(indices, args);
    }

    /**
     * Creates an index made of fixed positions, one per dimension.
     *
     * @param indices the fixed position for each successive dimension
     */
    public NDIndex(long... indices) {
        addIndices(indices);
    }

    /**
     * Creates an index that keeps every dimension before {@code axis} whole and slices {@code
     * axis} to {@code [min, max)} with no explicit step.
     *
     * @param axis the dimension to slice; must not be negative
     * @param min the slice start
     * @param max the slice end
     * @return the new index
     * @throws IllegalArgumentException if {@code axis} is negative
     */
    public static NDIndex sliceAxis(int axis, long min, long max) {
        if (axis < 0) {
            // Previously a negative axis silently sliced axis 0 instead.
            throw new IllegalArgumentException("The axis to slice can't be negative: " + axis);
        }
        return new NDIndex().addAllDim(axis).addSliceDim(min, max);
    }

    /**
     * Returns the number of dimensions this index addresses.
     *
     * <p>Each added element normally counts as one; a boolean index counts as many dimensions as
     * its mask has, and an ellipsis counts as none.
     *
     * @return the rank of this index
     */
    public int getRank() {
        return rank;
    }

    /**
     * Returns the position in {@link #getIndices()} at which the ellipsis sits.
     *
     * @return the ellipsis position, or -1 if this index has no ellipsis
     */
    public int getEllipsisIndex() {
        return ellipsisIndex;
    }

    /**
     * Returns the element at the given position.
     *
     * @param dimension position in the element list
     * @return the element at that position
     * @throws IndexOutOfBoundsException if {@code dimension} is out of range
     */
    public NDIndexElement get(int dimension) {
        return indices.get(dimension);
    }

    /**
     * Returns the element list backing this index.
     *
     * <p>This is the live internal list, not a copy; modifying it modifies this index without
     * updating its rank.
     *
     * @return the index elements
     */
    public List<NDIndexElement> getIndices() {
        return indices;
    }

    /**
     * Appends elements parsed from a comma-separated string.
     *
     * <p>Each item (surrounding whitespace ignored) is one of:
     *
     * <ul>
     *   <li>{@code *} - the whole dimension;
     *   <li>an integer such as {@code 2} or {@code -1} - a fixed position;
     *   <li>{@code min:max} or {@code min:max:step}, each part optional - a slice ({@code :} or
     *       {@code ::} with no numbers is treated like {@code *});
     *   <li>{@code null} - an {@link NDIndexNull} element;
     *   <li>{@code ...} - the ellipsis, allowed at most once per index.
     * </ul>
     *
     * <p>Any numeric position may be written as {@code {}} and is then taken, in order, from
     * {@code args} (an {@link Integer} or {@link Long}). A {@code {}} that forms a whole item also
     * accepts a boolean {@link NDArray} ({@link NDIndexBooleans}), an integer or floating-point
     * {@link NDArray} ({@link NDIndexTake}), or {@code null} ({@link NDIndexNull}).
     *
     * @param indices the comma-separated items
     * @param args values for the {@code {}} placeholders, in order
     * @return this index
     * @throws IllegalArgumentException if an item is malformed, a second ellipsis appears, an
     *     argument has an unsupported type, or the number of placeholders differs from {@code
     *     args.length}
     */
    public final NDIndex addIndices(String indices, Object... args) {
        int argIndex = 0;
        for (String item : indices.split(",")) {
            if ("...".equals(item.trim())) {
                if (ellipsisIndex != -1) {
                    throw new IllegalArgumentException("an index can only have a single ellipsis (\"...\")");
                }
                // Record the position in the element list, not in this string, so the value is
                // correct even when elements were added before this call.
                ellipsisIndex = this.indices.size();
                continue;
            }
            argIndex = addIndexItem(item, argIndex, args);
            // Count only real elements; the ellipsis contributes no rank. Counting here (instead
            // of adding the item count and subtracting one for an ellipsis) keeps the rank correct
            // across repeated calls after an ellipsis has been recorded.
            ++rank;
        }
        if (argIndex != args.length) {
            throw new IllegalArgumentException("Incorrect number of index arguments");
        }
        return this;
    }

    /**
     * Appends one fixed-position element per value.
     *
     * @param indices the fixed positions, in dimension order
     * @return this index
     */
    public final NDIndex addIndices(long... indices) {
        for (long position : indices) {
            this.indices.add(new NDIndexFixed(position));
        }
        rank += indices.length;
        return this;
    }

    /**
     * Appends a boolean-mask element.
     *
     * <p>The rank grows by the number of dimensions of {@code index}.
     *
     * @param index the mask array
     * @return this index
     */
    public NDIndex addBooleanIndex(NDArray index) {
        rank += index.getShape().dimension();
        indices.add(new NDIndexBooleans(index));
        return this;
    }

    /**
     * Appends an element selecting a whole dimension.
     *
     * @return this index
     */
    public NDIndex addAllDim() {
        ++rank;
        indices.add(new NDIndexAll());
        return this;
    }

    /**
     * Appends {@code count} elements, each selecting a whole dimension.
     *
     * @param count how many elements to add
     * @return this index
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public NDIndex addAllDim(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("The number of index dimensions to add can't be negative");
        }
        for (int i = 0; i < count; ++i) {
            indices.add(new NDIndexAll());
        }
        rank += count;
        return this;
    }

    /**
     * Appends a slice element from {@code min} to {@code max} with no explicit step.
     *
     * @param min the slice start
     * @param max the slice end
     * @return this index
     */
    public NDIndex addSliceDim(long min, long max) {
        ++rank;
        indices.add(new NDIndexSlice(min, max, null));
        return this;
    }

    /**
     * Appends a slice element from {@code min} to {@code max} with the given step.
     *
     * @param min the slice start
     * @param max the slice end
     * @param step the slice step
     * @return this index
     */
    public NDIndex addSliceDim(long min, long max, long step) {
        ++rank;
        indices.add(new NDIndexSlice(min, max, step));
        return this;
    }

    /**
     * Appends an {@link NDIndexPick} element built from {@code index}.
     *
     * @param index the array describing what to pick
     * @return this index
     */
    public NDIndex addPickDim(NDArray index) {
        ++rank;
        indices.add(new NDIndexPick(index));
        return this;
    }

    /**
     * Returns a stream over the elements of this index, in list order.
     *
     * @return the element stream
     */
    public Stream<NDIndexElement> stream() {
        return indices.stream();
    }

    /**
     * Parses one string item and appends exactly one element for it.
     *
     * @param indexItem the raw item (whitespace allowed around it)
     * @param argIndex position of the next unused entry in {@code args}
     * @param args the placeholder arguments
     * @return the position of the next unused argument after this item
     * @throws IllegalArgumentException if the item is malformed or its arguments are missing or
     *     of an unsupported type
     */
    private int addIndexItem(String indexItem, int argIndex, Object[] args) {
        String item = indexItem.trim();
        Matcher m = ITEM_PATTERN.matcher(item);
        if (!m.matches()) {
            throw new IllegalArgumentException("Invalid argument index: " + item);
        }
        // "null" is the only alternative without a capture group, so compare the text directly.
        if ("null".equals(item)) {
            indices.add(new NDIndexNull());
            return argIndex;
        }
        if (m.group(1) != null) { // "*"
            indices.add(new NDIndexAll());
            return argIndex;
        }
        String single = m.group(7);
        if (single != null) {
            if ("{}".equals(single)) {
                addPlaceholderElement(nextArg(args, argIndex));
                return argIndex + 1;
            }
            indices.add(new NDIndexFixed(Long.parseLong(single)));
            return argIndex;
        }

        // Otherwise the item is a slice "min:max[:step]" in which every part is optional.
        Long min = null;
        Long max = null;
        Long step = null;
        if (m.group(3) != null) {
            min = parseSliceItem(m.group(3), argIndex, args);
            if ("{}".equals(m.group(3))) {
                ++argIndex;
            }
        }
        if (m.group(4) != null) {
            max = parseSliceItem(m.group(4), argIndex, args);
            if ("{}".equals(m.group(4))) {
                ++argIndex;
            }
        }
        if (m.group(6) != null) {
            step = parseSliceItem(m.group(6), argIndex, args);
            if ("{}".equals(m.group(6))) {
                ++argIndex;
            }
        }
        if (min == null && max == null && step == null) {
            // ":" or "::" alone selects the whole dimension.
            indices.add(new NDIndexAll());
        } else {
            indices.add(new NDIndexSlice(min, max, step));
        }
        return argIndex;
    }

    /**
     * Appends the element for a stand-alone {@code {}} placeholder, chosen by the argument's type.
     *
     * @param arg the argument bound to the placeholder
     * @throws IllegalArgumentException if the argument's type is not supported
     */
    private void addPlaceholderElement(Object arg) {
        if (arg instanceof Integer) {
            indices.add(new NDIndexFixed(((Integer) arg).intValue()));
            return;
        }
        if (arg instanceof Long) {
            indices.add(new NDIndexFixed((Long) arg));
            return;
        }
        if (arg instanceof NDArray) {
            NDArray array = (NDArray) arg;
            if (array.getDataType().isBoolean()) {
                indices.add(new NDIndexBooleans(array));
                return;
            }
            if (array.getDataType().isInteger() || array.getDataType().isFloating()) {
                indices.add(new NDIndexTake(array));
                return;
            }
        } else if (arg == null) {
            indices.add(new NDIndexNull());
            return;
        }
        throw new IllegalArgumentException("Unknown argument: " + arg);
    }

    /**
     * Resolves one numeric slice component, which is a literal integer or a {@code {}}
     * placeholder.
     *
     * @param sliceItem the component text
     * @param argIndex position of the argument to use if the component is {@code {}}
     * @param args the placeholder arguments
     * @return the component value
     * @throws IllegalArgumentException if the argument is missing or is not an Integer or Long
     */
    private Long parseSliceItem(String sliceItem, int argIndex, Object... args) {
        if (!"{}".equals(sliceItem)) {
            return Long.parseLong(sliceItem);
        }
        Object arg = nextArg(args, argIndex);
        if (arg instanceof Integer) {
            return ((Integer) arg).longValue();
        }
        if (arg instanceof Long) {
            return (Long) arg;
        }
        throw new IllegalArgumentException("Unknown slice argument: " + arg);
    }

    /**
     * Returns {@code args[argIndex]}.
     *
     * <p>A missing argument is reported the same way as any other placeholder/argument count
     * mismatch, rather than surfacing as an ArrayIndexOutOfBoundsException.
     *
     * @throws IllegalArgumentException if there is no argument at {@code argIndex}
     */
    private static Object nextArg(Object[] args, int argIndex) {
        if (argIndex >= args.length) {
            throw new IllegalArgumentException("Incorrect number of index arguments");
        }
        return args[argIndex];
    }
}

