/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.array;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.DeviceLocalNDArray;

/**
 * {@link ArrayHolder} that keeps each named array wrapped in a {@link DeviceLocalNDArray}, indexed
 * by a {@link ConcurrentHashMap} so that single-name lookups, inserts and removals may be issued
 * from multiple threads.
 *
 * <p>Compound operations ({@link #initFrom(ArrayHolder)} and {@link #rename(String, String)}) are
 * not atomic: other threads may observe the intermediate state while they run.
 */
public class ThreadSafeArrayHolder
implements ArrayHolder {
    /** Name -> array wrapper. Backed by a ConcurrentHashMap, which rejects null keys and values. */
    private final Map<String, DeviceLocalNDArray> map = new ConcurrentHashMap<String, DeviceLocalNDArray>();
    /** Forwarded unchanged to every {@link DeviceLocalNDArray} this holder creates. */
    private final boolean lazyInit;

    /**
     * Creates an empty holder.
     *
     * @param lazyInit passed to the {@link DeviceLocalNDArray} constructor for each new entry
     *                 (presumably defers per-device copies until first access — confirm in that class)
     */
    public ThreadSafeArrayHolder(boolean lazyInit) {
        this.lazyInit = lazyInit;
    }

    /**
     * Returns whether an array is stored under {@code name}.
     *
     * @param name array name; must not be null
     * @return true if an entry exists for {@code name}
     * @throws NullPointerException if {@code name} is null
     */
    @Override
    public boolean hasArray(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        return this.map.containsKey(name);
    }

    /**
     * Returns the array stored under {@code name}, as provided by its {@link DeviceLocalNDArray}.
     *
     * @param name array name; must not be null
     * @return the stored array, or null if no entry exists for {@code name}
     * @throws NullPointerException if {@code name} is null
     */
    @Override
    public INDArray getArray(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        // Single lookup: a containsKey/get pair could race with a concurrent removeArray and NPE.
        DeviceLocalNDArray dla = this.map.get(name);
        return dla == null ? null : dla.get();
    }

    /**
     * Stores {@code array} under {@code name}, creating a new {@link DeviceLocalNDArray} if the name
     * is unknown or updating the existing one otherwise.
     *
     * @param name  array name; must not be null
     * @param array array to store; must not be null. Views are copied via {@code dup()} first.
     * @throws NullPointerException if {@code name} or {@code array} is null
     */
    @Override
    public void setArray(@NonNull String name, @NonNull INDArray array) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        // Views are materialised into a standalone copy before storing
        // (presumably DeviceLocalNDArray cannot hold a view — confirm).
        if (array.isView()) {
            array = array.dup();
        }
        DeviceLocalNDArray existing = this.map.get(name);
        if (existing == null) {
            // NOTE(review): UTF8 arrays are duplicated only when a new entry is created, not on
            // update() below — confirm whether update() also needs a private copy.
            INDArray toBroadcast = array.dataType() == DataType.UTF8 ? array.dup() : array;
            DeviceLocalNDArray created = new DeviceLocalNDArray(toBroadcast, this.lazyInit);
            // putIfAbsent is atomic on ConcurrentHashMap: if another thread inserted first, fall
            // through and update the winner instead of silently replacing it.
            existing = this.map.putIfAbsent(name, created);
            if (existing == null) {
                return;
            }
        }
        existing.update(array);
    }

    /**
     * Removes the entry for {@code name}.
     *
     * @param name array name; must not be null
     * @return the array that was stored under {@code name}, or null if there was none
     * @throws NullPointerException if {@code name} is null
     */
    @Override
    public INDArray removeArray(@NonNull String name) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        DeviceLocalNDArray arr = this.map.remove(name);
        if (arr == null) {
            return null;
        }
        return arr.get();
    }

    /**
     * @return the number of named arrays currently stored
     */
    @Override
    public int size() {
        return this.map.size();
    }

    /**
     * Replaces this holder's contents with the arrays of {@code arrayHolder}. Each array is copied in
     * through {@link #setArray(String, INDArray)}. Names whose array is null at copy time (for example,
     * because they were removed from the source concurrently) are skipped.
     *
     * <p>Not atomic: other threads may observe the cleared or partially-filled state.
     *
     * @param arrayHolder source holder; initialising from {@code this} is a no-op
     * @throws NullPointerException if {@code arrayHolder} is null (checked before anything is cleared)
     */
    @Override
    public void initFrom(ArrayHolder arrayHolder) {
        if (arrayHolder == null) {
            throw new NullPointerException("arrayHolder must not be null");
        }
        if (arrayHolder == this) {
            // Clearing first would wipe the very names we are about to copy.
            return;
        }
        this.map.clear();
        Collection<String> names = arrayHolder.arrayNames();
        for (String n : names) {
            INDArray arr = arrayHolder.getArray(n);
            if (arr != null) {
                this.setArray(n, arr);
            }
        }
    }

    /**
     * @return an unmodifiable, live view of the stored names (reflects later changes to this holder)
     */
    @Override
    public Collection<String> arrayNames() {
        return Collections.unmodifiableCollection(this.map.keySet());
    }

    /**
     * Moves the entry stored under {@code from} to {@code to}. Any entry already stored under
     * {@code to} is replaced. Not atomic: between the remove and the put, neither name is present.
     *
     * @param from current name; must not be null and must exist in this holder
     * @param to   new name; must not be null
     * @throws NullPointerException  if {@code from} or {@code to} is null
     * @throws IllegalStateException if no array is stored under {@code from}; the holder is unchanged
     */
    @Override
    public void rename(@NonNull String from, @NonNull String to) {
        if (from == null) {
            throw new NullPointerException("from is marked non-null but is null");
        }
        if (to == null) {
            throw new NullPointerException("to is marked non-null but is null");
        }
        DeviceLocalNDArray dl = this.map.remove(from);
        if (dl == null) {
            // Previously this fell through to ConcurrentHashMap.put(to, null), an opaque NPE.
            throw new IllegalStateException("Cannot rename array \"" + from + "\" to \"" + to
                    + "\": no array with that name exists");
        }
        this.map.put(to, dl);
    }
}

