/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Model;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListenerAdapter;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TrainingListenerAdapter} that saves model checkpoints to an output directory.
 *
 * <p>When {@code step} is positive, a checkpoint is saved every {@code step} epochs. A final
 * checkpoint is always saved when training ends, unless the last completed epoch was already
 * saved by the periodic rule. A non-positive {@code step} (for example the default {@code -1})
 * disables periodic saving, so only the end-of-training checkpoint is written.
 */
public class CheckpointsTrainingListener
extends TrainingListenerAdapter {
    private static final Logger logger = LoggerFactory.getLogger(CheckpointsTrainingListener.class);
    private final String outputDir;
    private String overrideModelName;
    private Consumer<Trainer> onSaveModel;
    private final int step;
    private int epoch;

    /**
     * Creates a listener that saves a checkpoint only at the end of training.
     *
     * @param outputDir the directory checkpoints are written to; must not be {@code null}
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public CheckpointsTrainingListener(String outputDir) {
        this(outputDir, null, -1);
    }

    /**
     * Creates a listener that saves a checkpoint only at the end of training.
     *
     * @param outputDir the directory checkpoints are written to; must not be {@code null}
     * @param overrideModelName name used when saving instead of the model's own name, or
     *     {@code null} to use the model's name
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public CheckpointsTrainingListener(String outputDir, String overrideModelName) {
        this(outputDir, overrideModelName, -1);
    }

    /**
     * Creates a listener that saves a checkpoint every {@code step} epochs and at the end of
     * training.
     *
     * @param outputDir the directory checkpoints are written to; must not be {@code null}
     * @param overrideModelName name used when saving instead of the model's own name, or
     *     {@code null} to use the model's name
     * @param step save every {@code step} epochs; any value {@code <= 0} disables periodic saving
     * @throws IllegalArgumentException if {@code outputDir} is {@code null}
     */
    public CheckpointsTrainingListener(String outputDir, String overrideModelName, int step) {
        // Validate before touching any state so a rejected instance is never half-initialized.
        if (outputDir == null) {
            throw new IllegalArgumentException("Can not save checkpoint without specifying an output directory");
        }
        this.outputDir = outputDir;
        this.overrideModelName = overrideModelName;
        this.step = step;
    }

    /**
     * Counts the completed epoch and saves a checkpoint when it is a multiple of {@code step}.
     *
     * @param trainer the trainer whose model is checkpointed
     */
    @Override
    public void onEpoch(Trainer trainer) {
        ++this.epoch;
        // outputDir is guaranteed non-null by the constructor, so no null check is needed here.
        if (this.step > 0 && this.epoch % this.step == 0) {
            this.saveModel(trainer);
        }
    }

    /**
     * Saves a final checkpoint unless the last completed epoch was already saved by
     * {@link #onEpoch(Trainer)}.
     *
     * @param trainer the trainer whose model is checkpointed
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        // Use the same "periodic saving enabled" test as onEpoch (step > 0). The previous check
        // (step == -1) let step == 0 reach the modulo and throw ArithmeticException.
        if (this.step <= 0 || this.epoch % this.step != 0) {
            this.saveModel(trainer);
        }
    }

    /**
     * Returns the name used when saving checkpoints instead of the model's own name.
     *
     * @return the override name, or {@code null} if the model's own name is used
     */
    public String getOverrideModelName() {
        return this.overrideModelName;
    }

    /**
     * Sets the name used when saving checkpoints instead of the model's own name.
     *
     * @param overrideModelName the override name, or {@code null} to use the model's own name
     */
    public void setOverrideModelName(String overrideModelName) {
        this.overrideModelName = overrideModelName;
    }

    /**
     * Registers a callback invoked just before each checkpoint is written.
     *
     * @param onSaveModel callback receiving the trainer, or {@code null} for no callback
     */
    public void setSaveModelCallback(Consumer<Trainer> onSaveModel) {
        this.onSaveModel = onSaveModel;
    }

    /**
     * Writes a checkpoint of the trainer's model into the output directory.
     *
     * <p>The model's {@code "Epoch"} property is set to the number of completed epochs, then the
     * optional save callback runs, and finally the model is saved. An {@link IOException} during
     * saving is logged rather than propagated, so a failed checkpoint does not abort training.
     *
     * @param trainer the trainer whose model is saved
     */
    protected void saveModel(Trainer trainer) {
        Model model = trainer.getModel();
        String modelName = model.getName();
        if (this.overrideModelName != null) {
            modelName = this.overrideModelName;
        }
        try {
            model.setProperty("Epoch", String.valueOf(this.epoch));
            if (this.onSaveModel != null) {
                this.onSaveModel.accept(trainer);
            }
            model.save(Paths.get(this.outputDir), modelName);
        } catch (IOException e) {
            logger.error("Failed to save checkpoint", e);
        }
    }
}

