/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.listener.TrainingListener;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Trains a single batch in data-parallel fashion: the batch is split across the trainer's devices
 * and each split's forward pass, loss evaluation and backward pass run on a worker thread.
 *
 * <p>Instances own a fixed-size thread pool. Call {@link #close()} when training is finished so
 * the pool's threads are released.
 */
public final class ParallelTrain implements AutoCloseable {
    private final ExecutorService executor;

    /**
     * Creates a helper with one worker thread per device.
     *
     * @param devices the devices to size the worker pool for; presumably the same devices the
     *     {@link Trainer} passed to {@link #trainBatch} splits over (TODO confirm with callers). If
     *     the trainer has more devices than this, the extra splits wait for a free worker thread.
     * @throws IllegalArgumentException if {@code devices} is empty
     */
    public ParallelTrain(Device[] devices) {
        if (devices.length == 0) {
            throw new IllegalArgumentException("At least one device is required for parallel training.");
        }
        this.executor = Executors.newFixedThreadPool(devices.length);
    }

    /**
     * Runs forward, loss and backward for every device split of {@code batch} in parallel. It then
     * waits for all splits to finish and notifies the trainer's listeners through
     * {@link TrainingListener#onTrainingBatch}.
     *
     * <p>Listeners are only notified when every split succeeded. If a split fails, the method
     * still waits for the remaining splits, then rethrows the first failure. Later failures are
     * attached to the first as suppressed exceptions.
     *
     * @param trainer the trainer providing the devices, data manager, loss and gradient collectors
     * @param batch the batch to train on; must belong to the same engine as {@code trainer}
     * @throws IllegalArgumentException if the batch and the trainer use different engines
     * @throws IllegalStateException if the calling thread is interrupted while waiting, or if a
     *     split fails with a checked exception
     * @throws RuntimeException an unchecked exception thrown by a failing split, rethrown as-is
     */
    public void trainBatch(Trainer trainer, Batch batch) {
        if (trainer.getManager().getEngine() != batch.getManager().getEngine()) {
            throw new IllegalArgumentException("The data must be on the same engine as the trainer. You may need to change one of your NDManagers.");
        }
        Batch[] splits = batch.split(trainer.getDevices(), false);
        // Several worker threads write into these maps at the same time, hence ConcurrentHashMap.
        TrainingListener.BatchData batchData = new TrainingListener.BatchData(batch, new ConcurrentHashMap<Device, NDList>(), new ConcurrentHashMap<Device, NDList>());
        List<Future<Boolean>> futures = new ArrayList<>(splits.length);
        for (Batch split : splits) {
            futures.add(executor.submit(() -> trainSplit(trainer, batchData, split)));
        }
        awaitAll(futures);
        trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
    }

    /**
     * Shuts down the worker pool. Splits that were already submitted still run to completion, but
     * no new work is accepted afterwards.
     */
    @Override
    public void close() {
        executor.shutdown();
    }

    /**
     * Trains one split on a worker thread and records its labels and predictions by device.
     *
     * @return always {@code true}; the value only signals completion through the {@link Future}
     */
    private static boolean trainSplit(Trainer trainer, TrainingListener.BatchData batchData, Batch split) {
        try (GradientCollector collector = trainer.newGradientCollector()) {
            NDList data = trainer.getDataManager().getData(split);
            NDList labels = trainer.getDataManager().getLabels(split);
            NDList preds = trainer.forward(data);
            // The "backward" metric covers both loss evaluation and the backward pass.
            long time = System.nanoTime();
            NDArray lossValue = trainer.getLoss().evaluate(labels, preds);
            collector.backward(lossValue);
            trainer.addMetric("backward", time);
            time = System.nanoTime();
            batchData.getLabels().put(((NDArray) labels.get(0)).getDevice(), labels);
            batchData.getPredictions().put(((NDArray) preds.get(0)).getDevice(), preds);
            trainer.addMetric("training-metrics", time);
            return true;
        }
    }

    /**
     * Waits for every future. It collects the first split failure and rethrows it only after all
     * splits have finished, so no worker is still touching shared state when the caller sees the
     * exception.
     */
    private static void awaitAll(List<Future<Boolean>> futures) {
        Throwable failure = null;
        for (Future<Boolean> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                // Restore the interrupt flag. Also prevent queued splits from starting; splits
                // that are already running cannot be stopped safely.
                Thread.currentThread().interrupt();
                for (Future<Boolean> pending : futures) {
                    pending.cancel(false);
                }
                throw new IllegalStateException("Interrupted while waiting for parallel training splits to finish.", e);
            } catch (ExecutionException e) {
                Throwable cause = e.getCause() != null ? e.getCause() : e;
                if (failure == null) {
                    failure = cause;
                } else if (cause != failure) {
                    failure.addSuppressed(cause);
                }
            }
        }
        if (failure == null) {
            return;
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        throw new IllegalStateException("Training failed on one of the batch splits.", failure);
    }
}

