/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Repository;
import ai.djl.repository.dataset.PreparedDataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

public abstract class AbstractImageFolder
extends RandomAccessDataset
implements PreparedDataset {
    private static final Set<String> EXT = new HashSet<String>(Arrays.asList(".jpg", ".jpeg", ".png", ".bmp", ".wbmp", ".gif"));
    protected Repository repository;
    protected Image.Flag flag;
    protected List<String> synset;
    protected PairList<String, Integer> items;

    protected AbstractImageFolder(ImageFolderBuilder<?> builder) {
        super(builder);
        this.flag = builder.flag;
        this.repository = builder.repository;
        this.synset = new ArrayList<String>();
        this.items = new PairList();
    }

    public Record get(NDManager manager, long index) throws IOException {
        Pair item = this.items.get(Math.toIntExact(index));
        Path imagePath = this.getImagePath((String)item.getKey());
        NDArray array = ImageFactory.getInstance().fromFile(imagePath).toNDArray(manager, this.flag);
        NDList d = new NDList(new NDArray[]{array});
        NDList l = new NDList(new NDArray[]{manager.create((Number)item.getValue())});
        return new Record(d, l);
    }

    protected long availableSize() {
        return this.items.size();
    }

    public List<String> getSynset() {
        return this.synset;
    }

    protected void listImages(File root, List<String> classes) {
        int label = 0;
        for (String className : classes) {
            File[] files;
            File classFolder = new File(root, className);
            if (!classFolder.exists() || !classFolder.isDirectory() || (files = classFolder.listFiles(this::isImage)) == null) continue;
            for (File file : files) {
                String path = file.getAbsolutePath();
                this.items.add(new Pair((Object)path, (Object)label));
            }
            ++label;
        }
    }

    protected abstract Path getImagePath(String var1);

    private boolean isImage(File file) {
        String path = file.getName();
        if (!file.isFile() || file.isHidden() || path.startsWith(".")) {
            return false;
        }
        int extensionIndex = path.lastIndexOf(46);
        if (extensionIndex < 0) {
            return false;
        }
        return EXT.contains(path.substring(extensionIndex).toLowerCase());
    }

    public static abstract class ImageFolderBuilder<T extends ImageFolderBuilder>
    extends RandomAccessDataset.BaseBuilder<T> {
        Repository repository;
        Image.Flag flag = Image.Flag.COLOR;

        protected ImageFolderBuilder() {
            this.pipeline = new Pipeline(new Transform[]{new ToTensor()});
        }

        public T optFlag(Image.Flag flag) {
            this.flag = flag;
            return (T)((Object)((ImageFolderBuilder)this.self()));
        }

        public T setRepository(Repository repository) {
            this.repository = repository;
            return (T)((Object)((ImageFolderBuilder)this.self()));
        }
    }
}

