/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.tensorflow.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.tensorflow.engine.SavedModelBundle;
import ai.djl.tensorflow.engine.TfNDManager;
import ai.djl.tensorflow.engine.TfSymbolBlock;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.util.Utils;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.RunOptions;

public class TfModel
extends BaseModel {
    private static final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default";

    TfModel(String name, Device device) {
        super(name);
        this.properties = new ConcurrentHashMap();
        this.manager = TfNDManager.getSystemManager().newSubManager(device);
        this.manager.setName("tfModel");
    }

    public void load(Path modelPath, String prefix, Map<String, ?> options) throws FileNotFoundException, MalformedModelException {
        Path exportDir;
        this.setModelDir(modelPath);
        this.wasLoaded = true;
        if (prefix == null) {
            prefix = this.modelName;
        }
        if ((exportDir = this.findModelDir(prefix)) == null && (exportDir = this.findModelDir("saved_model.pb")) == null) {
            throw new FileNotFoundException("No TensorFlow model found in: " + this.modelDir);
        }
        String[] tags = null;
        ConfigProto configProto = null;
        RunOptions runOptions = null;
        String signatureDefKey = DEFAULT_SERVING_SIGNATURE_DEF_KEY;
        if (options != null) {
            Object tagOption = options.get("Tags");
            if (tagOption instanceof String[]) {
                tags = (String[])tagOption;
            } else if (tagOption instanceof String) {
                tags = ((String)tagOption).isEmpty() ? Utils.EMPTY_ARRAY : ((String)tagOption).split(",");
            }
            Object config = options.get("ConfigProto");
            if (config instanceof ConfigProto) {
                configProto = (ConfigProto)config;
            } else if (config instanceof String) {
                try {
                    byte[] buf = Base64.getDecoder().decode((String)config);
                    configProto = ConfigProto.parseFrom((byte[])buf);
                }
                catch (InvalidProtocolBufferException e) {
                    throw new MalformedModelException("Invalid ConfigProto: " + config, (Throwable)e);
                }
            }
            Object run = options.get("RunOptions");
            if (run instanceof RunOptions) {
                runOptions = (RunOptions)run;
            } else if (run instanceof String) {
                try {
                    byte[] buf = Base64.getDecoder().decode((String)run);
                    runOptions = RunOptions.parseFrom((byte[])buf);
                }
                catch (InvalidProtocolBufferException e) {
                    throw new MalformedModelException("Invalid RunOptions: " + run, (Throwable)e);
                }
            }
            if (options.containsKey("SignatureDefKey")) {
                signatureDefKey = (String)options.get("SignatureDefKey");
            }
        }
        if (tags == null) {
            tags = new String[]{"serve"};
        }
        if (configProto == null) {
            configProto = JavacppUtils.getSessionConfig();
        }
        SavedModelBundle bundle = JavacppUtils.loadSavedModelBundle(exportDir.toString(), tags, configProto, runOptions);
        this.block = new TfSymbolBlock(bundle, signatureDefKey);
    }

    private Path findModelDir(String prefix) {
        Path file;
        Path path = this.modelDir.resolve(prefix);
        if (!Files.exists(path, new LinkOption[0])) {
            return null;
        }
        if (Files.isRegularFile(path, new LinkOption[0])) {
            return this.modelDir;
        }
        if (Files.isDirectory(path, new LinkOption[0]) && Files.exists(file = path.resolve("saved_model.pb"), new LinkOption[0]) && Files.isRegularFile(file, new LinkOption[0])) {
            return path;
        }
        return null;
    }

    public void save(Path modelPath, String newModelName) {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    public Block getBlock() {
        return this.block;
    }

    public void setBlock(Block block) {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    public NDManager getNDManager() {
        return this.manager;
    }

    public String[] getArtifactNames() {
        try {
            List files = Files.walk(this.modelDir, new FileVisitOption[0]).filter(x$0 -> Files.isRegularFile(x$0, new LinkOption[0])).collect(Collectors.toList());
            ArrayList<String> ret = new ArrayList<String>(files.size());
            for (Path path : files) {
                String fileName = path.toFile().getName();
                if (fileName.endsWith(".pb")) continue;
                Path relative = this.modelDir.relativize(path);
                ret.add(relative.toString());
            }
            return ret.toArray(Utils.EMPTY_ARRAY);
        }
        catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    public void close() {
        if (this.block != null) {
            ((TfSymbolBlock)this.block).close();
            this.block = null;
        }
        super.close();
    }
}

