package com.qualcomm.qti.snpe.internal;

import android.os.SystemClock;
import android.util.Log;
import android.util.Pair;
import com.qualcomm.qti.snpe.FloatTensor;
import com.qualcomm.qti.snpe.NeuralNetwork;
import com.qualcomm.qti.snpe.SNPE;
import com.qualcomm.qti.snpe.SnpeError;
import com.qualcomm.qti.snpe.TF8UserBufferTensor;
import com.qualcomm.qti.snpe.Tensor;
import com.qualcomm.qti.snpe.TensorAttributes;
import com.qualcomm.qti.snpe.UserBufferTensor;
import com.qualcomm.qti.snpe.internal.util.JniInputBundle;
import com.qualcomm.qti.snpe.internal.util.JniOutputBundle;
import com.qualcomm.qti.snpe.internal.util.NativeBufferOutputStream;
import java.io.File;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: classes2.dex */
public final class NativeNetwork implements NeuralNetwork {
    private static final HashMap<NeuralNetwork.Runtime, Boolean> mRuntimeAvailability = new HashMap<>();
    private TensorAttributesMap mInputTensors;
    private boolean mIsDebugEnabled;
    private String mModelVersion;
    private long mNativeHandle;
    private Set<String> mOutputLayerNames;
    private TensorAttributesMap mOutputTensors;
    private NeuralNetwork.Runtime mTargetRuntime;

    public NativeNetwork(NeuralNetwork.Runtime runtime, NativeBufferOutputStream nativeBufferOutputStream, Set<String> set, boolean z, NeuralNetwork.PerformanceProfile performanceProfile, NeuralNetwork.ExecutionPriorityHint executionPriorityHint, boolean z2, boolean z3, boolean z4, boolean z5, String str, String str2, Map<String, int[]> map) {
        ByteBuffer marshall = createInitBundle(runtime, set, z, performanceProfile, executionPriorityHint, z2, z3, z4, z5, str, str2, map).marshall();
        long elapsedRealtime = SystemClock.elapsedRealtime();
        try {
            this.mNativeHandle = nativeInitFromBuffer(nativeBufferOutputStream.getNativeAddress(), marshall, marshall.capacity());
            long elapsedRealtime2 = SystemClock.elapsedRealtime();
            if (z) {
                Log.d(SNPE.LOG_TAG, "Network initialization took " + (elapsedRealtime2 - elapsedRealtime) + " ms");
            }
            init(runtime, set, z);
        } catch (IllegalStateException e) {
            throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
        }
    }

    public NativeNetwork(NeuralNetwork.Runtime runtime, File file, Set<String> set, boolean z, NeuralNetwork.PerformanceProfile performanceProfile, NeuralNetwork.ExecutionPriorityHint executionPriorityHint, boolean z2, boolean z3, boolean z4, boolean z5, String str, String str2, Map<String, int[]> map) {
        ByteBuffer marshall = createInitBundle(runtime, set, z, performanceProfile, executionPriorityHint, z2, z3, z4, z5, str, str2, map).marshall();
        long elapsedRealtime = SystemClock.elapsedRealtime();
        try {
            this.mNativeHandle = nativeInitFromFile(file.getAbsolutePath(), marshall, marshall.capacity());
            long elapsedRealtime2 = SystemClock.elapsedRealtime();
            if (z) {
                Log.d(SNPE.LOG_TAG, "Network initialization took " + (elapsedRealtime2 - elapsedRealtime) + " ms");
            }
            init(runtime, set, z);
        } catch (IllegalStateException e) {
            throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
        }
    }

    private void assertInputTensor(String str, Tensor tensor) {
        if (tensor == null) {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_TENSOR_NULL, str);
        }
        if (tensor instanceof NativeFloatTensor) {
            if (!this.mInputTensors.containsKey(str)) {
                throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_INPUT_NETWORK_TENSORS, str);
            }
            TensorUtil.assertShape(str, this.mInputTensors.get(str).getDims(), tensor.getShape());
        } else {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_TENSOR_INSTANCE, "" + NativeFloatTensor.class);
        }
    }

    private void assertInputTensor(String str, UserBufferTensor userBufferTensor) {
        if (userBufferTensor == null) {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_TENSOR_NULL, str);
        }
        if (userBufferTensor instanceof NativeUserBufferHandle) {
            if (!this.mInputTensors.containsKey(str)) {
                throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_INPUT_NETWORK_TENSORS, str);
            }
            return;
        }
        throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_TENSOR_INSTANCE, "" + NativeUserBufferHandle.class);
    }

    private void assertOutputTensor(String str, UserBufferTensor userBufferTensor) {
        if (userBufferTensor == null) {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_TENSOR_NULL, str);
        }
        if (userBufferTensor instanceof NativeUserBufferHandle) {
            if (!this.mOutputTensors.containsKey(str)) {
                throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_OUTPUT_NETWORK_TENSORS, str);
            }
            return;
        }
        throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_OUTPUT_INSTANCE, "" + NativeUserBufferHandle.class);
    }

    private static void assertSelectedOutputLayer(Set<String> set, Set<String> set2) {
        for (String str : set2) {
            if (!set.contains(str)) {
                throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_SELECTED_LAYER, str);
            }
        }
    }

    private static JniInputBundle createInitBundle(NeuralNetwork.Runtime runtime, Set<String> set, boolean z, NeuralNetwork.PerformanceProfile performanceProfile, NeuralNetwork.ExecutionPriorityHint executionPriorityHint, boolean z2, boolean z3, boolean z4, boolean z5, String str, String str2, Map<String, int[]> map) {
        JniInputBundle jniInputBundle = new JniInputBundle();
        jniInputBundle.writeInt(runtime.ordinal);
        jniInputBundle.writeInt(performanceProfile.ordinal);
        jniInputBundle.writeInt(executionPriorityHint.ordinal);
        jniInputBundle.writeStrings(set);
        jniInputBundle.writeBoolean(z);
        jniInputBundle.writeBoolean(z2);
        jniInputBundle.writeBoolean(z3);
        jniInputBundle.writeBoolean(z4);
        jniInputBundle.writeBoolean(z5);
        jniInputBundle.writeString(str);
        jniInputBundle.writeString(str2);
        jniInputBundle.writeInt(map.size());
        for (Map.Entry<String, int[]> entry : map.entrySet()) {
            jniInputBundle.writeString(entry.getKey());
            jniInputBundle.writeInts(entry.getValue());
        }
        return jniInputBundle;
    }

    private void doRelease() {
        long j = this.mNativeHandle;
        if (j != 0) {
            try {
                nativeRelease(j);
                this.mNativeHandle = 0L;
            } catch (IllegalStateException e) {
                throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
            }
        }
    }

    private void init(NeuralNetwork.Runtime runtime, Set<String> set, boolean z) {
        try {
            JniOutputBundle nativeGetModelMetadata = nativeGetModelMetadata(this.mNativeHandle);
            this.mModelVersion = nativeGetModelMetadata.readString();
            this.mOutputLayerNames = readOutputLayerNames(nativeGetModelMetadata);
            this.mInputTensors = new TensorAttributesMap(nativeGetModelMetadata);
            this.mOutputTensors = new TensorAttributesMap(nativeGetModelMetadata);
            this.mTargetRuntime = runtime;
            this.mIsDebugEnabled = z;
            assertSelectedOutputLayer(this.mOutputLayerNames, set);
        } catch (IllegalStateException e) {
            throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
        }
    }

    public static synchronized boolean isRuntimeAvailable(NeuralNetwork.Runtime runtime) {
        boolean booleanValue;
        synchronized (NativeNetwork.class) {
            try {
                if (!mRuntimeAvailability.containsKey(runtime)) {
                    mRuntimeAvailability.put(runtime, Boolean.valueOf(nativeIsRuntimeAvailable(runtime.ordinal)));
                }
                booleanValue = mRuntimeAvailability.get(runtime).booleanValue();
            } catch (IllegalStateException e) {
                throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
            }
        }
        return booleanValue;
    }

    private native JniOutputBundle nativeForward(long j, ByteBuffer byteBuffer, int i, ByteBuffer[] byteBufferArr);

    private native JniOutputBundle nativeGetModelMetadata(long j);

    private native long nativeInitFromBuffer(long j, ByteBuffer byteBuffer, int i);

    private native long nativeInitFromFile(String str, ByteBuffer byteBuffer, int i);

    private static native boolean nativeIsRuntimeAvailable(int i);

    private native void nativeRelease(long j);

    private native JniOutputBundle nativeUserBufferForward(long j, ByteBuffer byteBuffer, int i);

    private static Set<String> readOutputLayerNames(JniOutputBundle jniOutputBundle) {
        HashSet hashSet = new HashSet();
        int readInt = jniOutputBundle.readInt();
        for (int i = 0; i < readInt; i++) {
            hashSet.add(jniOutputBundle.readString());
        }
        return hashSet;
    }

    private Map<String, FloatTensor> readOutputTensors(JniOutputBundle jniOutputBundle) {
        HashMap hashMap = new HashMap();
        int readInt = jniOutputBundle.readInt();
        for (int i = 0; i < readInt; i++) {
            Pair<String, FloatTensor> readTensor = readTensor(jniOutputBundle);
            hashMap.put(readTensor.first, readTensor.second);
        }
        return hashMap;
    }

    private boolean readOutputUserBufferTensors(JniOutputBundle jniOutputBundle, Map<String, ? extends UserBufferTensor> map) {
        int readInt = jniOutputBundle.readInt();
        for (int i = 0; i < readInt; i++) {
            String readString = jniOutputBundle.readString();
            if (!map.containsKey(readString)) {
                return false;
            }
            readUserBufferTensor(jniOutputBundle, map.get(readString));
        }
        return true;
    }

    private Pair<String, FloatTensor> readTensor(JniOutputBundle jniOutputBundle) {
        String readString = jniOutputBundle.readString();
        int[] readIntArray = jniOutputBundle.readIntArray();
        float[] readFloatArray = jniOutputBundle.readFloatArray();
        NativeFloatTensor nativeFloatTensor = new NativeFloatTensor(readIntArray);
        nativeFloatTensor.write(readFloatArray, 0, readFloatArray.length, new int[0]);
        return Pair.create(readString, nativeFloatTensor);
    }

    private void readUserBufferTensor(JniOutputBundle jniOutputBundle, UserBufferTensor userBufferTensor) {
        if (userBufferTensor.getEncoding() == UserBufferTensor.Encoding.TF8) {
            if (!(userBufferTensor instanceof NativeTF8UserBufferTensor)) {
                throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_USER_BUFFER_TENSOR);
            }
            NativeTF8UserBufferTensor nativeTF8UserBufferTensor = (NativeTF8UserBufferTensor) NativeTF8UserBufferTensor.class.cast(userBufferTensor);
            nativeTF8UserBufferTensor.setStepExactly0(jniOutputBundle.readInt());
            nativeTF8UserBufferTensor.setQuantizedStepSize(jniOutputBundle.readFloat());
        }
    }

    private List<ByteBuffer> writeInputTensors(JniInputBundle jniInputBundle, Map<String, NativeFloatTensor> map) {
        jniInputBundle.writeInt(map.size());
        LinkedList linkedList = new LinkedList();
        for (String str : map.keySet()) {
            linkedList.add(writeTensor(jniInputBundle, str, map.get(str)));
        }
        return linkedList;
    }

    private ByteBuffer writeTensor(JniInputBundle jniInputBundle, String str, NativeFloatTensor nativeFloatTensor) {
        jniInputBundle.writeLong(nativeFloatTensor.getNativeHandle());
        jniInputBundle.writeString(str);
        ByteBuffer asByteBuffer = nativeFloatTensor.asByteBuffer();
        jniInputBundle.writeInt(asByteBuffer.capacity());
        return asByteBuffer;
    }

    private void writeUserBufferTensors(JniInputBundle jniInputBundle, Map<String, ? extends UserBufferTensor> map) {
        jniInputBundle.writeInt(map.size());
        for (String str : map.keySet()) {
            jniInputBundle.writeLong(((NativeUserBufferHandle) NativeUserBufferHandle.class.cast(map.get(str))).getNativeHandle());
            jniInputBundle.writeString(str);
        }
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public FloatTensor createFloatTensor(int... iArr) {
        if (iArr == null || iArr.length == 0) {
            throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_TENSOR_SHAPE_NULL_EMPTY);
        }
        for (int i : iArr) {
            if (i < 1) {
                throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_SHAPE_DIM);
            }
        }
        return new NativeFloatTensor(iArr);
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public TF8UserBufferTensor createTF8UserBufferTensor(int i, int[] iArr, int i2, float f, ByteBuffer byteBuffer) {
        if (i <= 0) {
            throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_BUFFER_SIZE);
        }
        for (int i3 : iArr) {
            if (i3 < 1) {
                throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_STRIDE);
            }
        }
        if (i2 < 0 || i2 > 255) {
            throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_STEP_ZERO);
        }
        if (f <= 0.0d) {
            throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_STEP_VALUE);
        }
        if (byteBuffer != null) {
            return new NativeTF8UserBufferTensor(i, iArr, i2, f, byteBuffer);
        }
        throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_BUFFER_NULL);
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public Map<String, FloatTensor> execute(Map<String, FloatTensor> map) {
        if (map == null || map.isEmpty()) {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_INPUT_TENSOR_NULL_EMPTY);
        }
        HashMap hashMap = new HashMap();
        for (String str : map.keySet()) {
            Tensor tensor = map.get(str);
            assertInputTensor(str, tensor);
            hashMap.put(str, NativeFloatTensor.class.cast(tensor));
        }
        JniInputBundle jniInputBundle = new JniInputBundle();
        List<ByteBuffer> writeInputTensors = writeInputTensors(jniInputBundle, hashMap);
        ByteBuffer marshall = jniInputBundle.marshall();
        long elapsedRealtime = SystemClock.elapsedRealtime();
        try {
            JniOutputBundle nativeForward = nativeForward(this.mNativeHandle, marshall, marshall.capacity(), (ByteBuffer[]) writeInputTensors.toArray(new ByteBuffer[writeInputTensors.size()]));
            long elapsedRealtime2 = SystemClock.elapsedRealtime();
            if (this.mIsDebugEnabled) {
                Log.d(SNPE.LOG_TAG, "Network forward took " + (elapsedRealtime2 - elapsedRealtime) + " ms");
            }
            return readOutputTensors(nativeForward);
        } catch (IllegalStateException e) {
            throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
        }
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public boolean execute(Map<String, ? extends UserBufferTensor> map, Map<String, ? extends UserBufferTensor> map2) {
        if (map == null || map.isEmpty()) {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_INPUT_TENSOR_NULL_EMPTY);
        }
        if (map2 == null || map2.isEmpty()) {
            throw SnpeError.getInstance().getInvalidInputError(SnpeError.Error.INVALID_OUTPUT_TENSOR_NULL_EMPTY);
        }
        for (String str : map.keySet()) {
            assertInputTensor(str, map.get(str));
        }
        for (String str2 : map2.keySet()) {
            assertOutputTensor(str2, map2.get(str2));
        }
        JniInputBundle jniInputBundle = new JniInputBundle();
        writeUserBufferTensors(jniInputBundle, map);
        writeUserBufferTensors(jniInputBundle, map2);
        ByteBuffer marshall = jniInputBundle.marshall();
        long elapsedRealtime = SystemClock.elapsedRealtime();
        try {
            JniOutputBundle nativeUserBufferForward = nativeUserBufferForward(this.mNativeHandle, marshall, marshall.capacity());
            long elapsedRealtime2 = SystemClock.elapsedRealtime();
            if (this.mIsDebugEnabled) {
                Log.d(SNPE.LOG_TAG, "Network forward took " + (elapsedRealtime2 - elapsedRealtime) + " ms");
            }
            return readOutputUserBufferTensors(nativeUserBufferForward, map2);
        } catch (Exception e) {
            throw SnpeError.getInstance().getSnpeNativeError(e.getMessage());
        }
    }

    protected void finalize() throws Throwable {
        super.finalize();
        doRelease();
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public Set<String> getInputTensorsNames() {
        return this.mInputTensors.keySet();
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public Map<String, int[]> getInputTensorsShapes() {
        return new HashMap(this.mInputTensors.getTensorShapes());
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public String getModelVersion() {
        return this.mModelVersion;
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public Set<String> getOutputLayers() {
        return this.mOutputLayerNames;
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public Set<String> getOutputTensorsNames() {
        return this.mOutputTensors.keySet();
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public Map<String, int[]> getOutputTensorsShapes() {
        return new HashMap(this.mOutputTensors.getTensorShapes());
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public NeuralNetwork.Runtime getRuntime() {
        return this.mTargetRuntime;
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public TensorAttributes getTensorAttributes(String str) {
        if (this.mInputTensors.containsKey(str)) {
            return this.mInputTensors.get(str);
        }
        if (this.mOutputTensors.containsKey(str)) {
            return this.mOutputTensors.get(str);
        }
        throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_TENSOR_NAME, str);
    }

    @Override // com.qualcomm.qti.snpe.NeuralNetwork
    public synchronized void release() {
        if (this.mNativeHandle == 0) {
            throw SnpeError.getInstance().getIllegalError(SnpeError.Error.ILLEGAL_NETWORK_RELEASE);
        }
        doRelease();
    }
}
