package io.github.winfordguo.flutter_pytorch_lite;

import io.flutter.embedding.engine.plugins.FlutterPlugin;
import io.flutter.plugin.common.BinaryMessenger;
import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;
import io.flutter.plugin.common.StandardMethodCodec;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.pytorch.DType;
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.Tensor;

/**
 * Flutter plugin exposing PyTorch Lite modules over the {@code flutter_pytorch_lite} method
 * channel.
 *
 * <p>Handled calls: {@code load} (argument {@code filePath}; replies with an integer module id),
 * {@code forward} (arguments {@code moduleId} and {@code inputs}; replies with the encoded output)
 * and {@code destroy} (argument {@code moduleId}).
 *
 * <p>Values cross the channel as maps holding a {@code typeCode} entry (one of the
 * {@code TYPE_CODE_*} constants) and, for non-null values, a {@code data} entry. Tensors are maps
 * with {@code shape}, {@code dtype} and {@code memoryFormat} (the PyTorch JNI codes) and
 * {@code data}.
 */
public class FlutterPytorchLitePlugin implements FlutterPlugin, MethodChannel.MethodCallHandler {
    // Wire type codes carried in the "typeCode" entry of encoded values. iValueToMap compares them
    // with getTypeCode(IValue), so they must stay equal to IValue's internal type codes.
    private static final int TYPE_CODE_NULL = 1;
    private static final int TYPE_CODE_TENSOR = 2;
    private static final int TYPE_CODE_BOOL = 3;
    private static final int TYPE_CODE_LONG = 4;
    private static final int TYPE_CODE_DOUBLE = 5;
    private static final int TYPE_CODE_STRING = 6;
    private static final int TYPE_CODE_TUPLE = 7;
    private static final int TYPE_CODE_BOOL_LIST = 8;
    private static final int TYPE_CODE_LONG_LIST = 9;
    private static final int TYPE_CODE_DOUBLE_LIST = 10;
    private static final int TYPE_CODE_TENSOR_LIST = 11;
    private static final int TYPE_CODE_LIST = 12;
    private static final int TYPE_CODE_DICT_STRING_KEY = 13;
    private static final int TYPE_CODE_DICT_LONG_KEY = 14;

    /**
     * Modules loaded through this plugin instance's channel, keyed by the id returned from
     * {@code load}. The map is per instance, so detaching one engine cannot destroy modules that
     * another engine still uses.
     */
    private final ConcurrentHashMap<Integer, Module> modules = new ConcurrentHashMap<>();

    /** Source of module ids. Object#hashCode is not unique, so it cannot safely serve as an id. */
    private final AtomicInteger nextModuleId = new AtomicInteger(1);

    private MethodChannel channel;

    /**
     * Reads the non-public int field {@code jniCode} of an object (used for PyTorch enums such as
     * {@link DType} and {@link MemoryFormat}) via reflection.
     *
     * @param t the object to inspect; may be null
     * @return the field value, or 0 if it cannot be read (missing field, access failure, null
     *     argument)
     */
    public static <T> int getJniCode(T t) {
        try {
            Field jniCodeField = t.getClass().getDeclaredField("jniCode");
            jniCodeField.setAccessible(true);
            return jniCodeField.getInt(t);
        } catch (Exception ignored) {
            // Best-effort lookup: callers treat an unreadable code as 0.
            return 0;
        }
    }

    /**
     * Reads the private {@code mTypeCode} field of an {@link IValue} via reflection.
     *
     * @param iValue the value to inspect; may be null
     * @return the type code, or 0 if it cannot be read
     */
    public static int getTypeCode(IValue iValue) {
        try {
            Field typeCodeField = iValue.getClass().getDeclaredField("mTypeCode");
            typeCodeField.setAccessible(true);
            return typeCodeField.getInt(iValue);
        } catch (Exception ignored) {
            // Best-effort lookup: callers treat an unreadable code as 0.
            return 0;
        }
    }

    /**
     * Finds the {@link DType} whose {@code jniCode} equals {@code num}.
     *
     * @throws IllegalArgumentException if {@code num} is null or matches no DType
     */
    public static DType parseDType(Integer num) {
        if (num != null) {
            for (DType dType : DType.values()) {
                if (getJniCode(dType) == num) {
                    return dType;
                }
            }
        }
        throw new IllegalArgumentException(num + " is not a valid code for DType.");
    }

    /**
     * Finds the {@link MemoryFormat} whose {@code jniCode} equals {@code num}.
     *
     * @throws IllegalArgumentException if {@code num} is null or matches no MemoryFormat
     */
    public static MemoryFormat parseMemoryFormat(Integer num) {
        if (num != null) {
            for (MemoryFormat memoryFormat : MemoryFormat.values()) {
                if (getJniCode(memoryFormat) == num) {
                    return memoryFormat;
                }
            }
        }
        throw new IllegalArgumentException(num + " is not a valid code for MemoryFormat.");
    }

    /**
     * Encodes an {@link IValue} as a channel map {@code {typeCode, data}}. Null values carry only
     * {@code typeCode}.
     *
     * <p>NOTE(review): the decompiled original of this method was not recoverable. This body is
     * rebuilt as the inverse of {@link #mapToIValue}, using the same type codes and payload
     * shapes. Confirm it against the upstream plugin source.
     *
     * @param iValue the value produced by a module
     * @return the encoded value
     * @throws IllegalArgumentException if the value's type code is not one of {@code TYPE_CODE_*}
     */
    HashMap<String, Object> iValueToMap(IValue iValue) {
        HashMap<String, Object> encoded = new HashMap<>();
        int typeCode = getTypeCode(iValue);
        encoded.put("typeCode", typeCode);
        switch (typeCode) {
            case TYPE_CODE_NULL:
                break;
            case TYPE_CODE_TENSOR:
                encoded.put("data", tensorToMap(iValue.toTensor()));
                break;
            case TYPE_CODE_BOOL:
                encoded.put("data", iValue.toBool());
                break;
            case TYPE_CODE_LONG:
                encoded.put("data", iValue.toLong());
                break;
            case TYPE_CODE_DOUBLE:
                encoded.put("data", iValue.toDouble());
                break;
            case TYPE_CODE_STRING:
                encoded.put("data", iValue.toStr());
                break;
            case TYPE_CODE_TUPLE:
                encoded.put("data", iValuesToMaps(iValue.toTuple()));
                break;
            case TYPE_CODE_BOOL_LIST: {
                // mapToIValue expects bool lists as a List of Boolean, and the standard codec has
                // no boolean[] encoding, so the array is boxed here.
                boolean[] flags = iValue.toBoolList();
                ArrayList<Boolean> boxed = new ArrayList<>(flags.length);
                for (boolean flag : flags) {
                    boxed.add(flag);
                }
                encoded.put("data", boxed);
                break;
            }
            case TYPE_CODE_LONG_LIST:
                encoded.put("data", iValue.toLongList());
                break;
            case TYPE_CODE_DOUBLE_LIST:
                encoded.put("data", iValue.toDoubleList());
                break;
            case TYPE_CODE_TENSOR_LIST: {
                Tensor[] tensors = iValue.toTensorList();
                ArrayList<HashMap<String, Object>> tensorMaps = new ArrayList<>(tensors.length);
                for (Tensor tensor : tensors) {
                    tensorMaps.add(tensorToMap(tensor));
                }
                encoded.put("data", tensorMaps);
                break;
            }
            case TYPE_CODE_LIST:
                encoded.put("data", iValuesToMaps(iValue.toList()));
                break;
            case TYPE_CODE_DICT_STRING_KEY: {
                HashMap<String, HashMap<String, Object>> dict = new HashMap<>();
                for (Map.Entry<String, IValue> entry : iValue.toDictStringKey().entrySet()) {
                    dict.put(entry.getKey(), iValueToMap(entry.getValue()));
                }
                encoded.put("data", dict);
                break;
            }
            case TYPE_CODE_DICT_LONG_KEY: {
                HashMap<Long, HashMap<String, Object>> dict = new HashMap<>();
                for (Map.Entry<Long, IValue> entry : iValue.toDictLongKey().entrySet()) {
                    dict.put(entry.getKey(), iValueToMap(entry.getValue()));
                }
                encoded.put("data", dict);
                break;
            }
            default:
                throw new IllegalArgumentException("IValue to map error: unsupported type code " + typeCode + ".");
        }
        return encoded;
    }

    /**
     * Decodes a channel map {@code {typeCode, data}} into an {@link IValue}.
     *
     * @param hashMap the encoded value as received from the method channel
     * @return the decoded value
     * @throws IllegalArgumentException if {@code typeCode} is missing or unknown, or a numeric
     *     list payload has an unsupported representation
     * @throws ClassCastException if {@code data} does not have the type implied by the type code
     */
    IValue mapToIValue(HashMap<String, Object> hashMap) {
        Integer typeCode = (Integer) hashMap.get("typeCode");
        if (typeCode == null) {
            throw new IllegalArgumentException("Map to IValue error: missing typeCode.");
        }
        if (typeCode == TYPE_CODE_NULL) {
            return IValue.optionalNull();
        }
        Object data = hashMap.get("data");
        switch (typeCode) {
            case TYPE_CODE_TENSOR:
                return IValue.from(mapToTensor(asEncodedMap(data)));
            case TYPE_CODE_BOOL:
                return IValue.from(((Boolean) data).booleanValue());
            case TYPE_CODE_LONG:
                // The standard codec decodes ints that fit in 32 bits as Integer and larger ones
                // as Long, so a cast to Long alone would fail on small values.
                return IValue.from(((Number) data).longValue());
            case TYPE_CODE_DOUBLE:
                return IValue.from(((Number) data).doubleValue());
            case TYPE_CODE_STRING:
                return IValue.from((String) data);
            case TYPE_CODE_TUPLE:
                return IValue.tupleFrom(mapsToIValues((List<?>) data));
            case TYPE_CODE_BOOL_LIST: {
                List<?> items = (List<?>) data;
                boolean[] flags = new boolean[items.size()];
                for (int i = 0; i < flags.length; i++) {
                    flags[i] = (Boolean) items.get(i);
                }
                return IValue.listFrom(flags);
            }
            case TYPE_CODE_LONG_LIST:
                return IValue.listFrom(toLongArray(data));
            case TYPE_CODE_DOUBLE_LIST:
                return IValue.listFrom(toDoubleArray(data));
            case TYPE_CODE_TENSOR_LIST: {
                List<?> items = (List<?>) data;
                Tensor[] tensors = new Tensor[items.size()];
                for (int i = 0; i < tensors.length; i++) {
                    tensors[i] = mapToTensor(asEncodedMap(items.get(i)));
                }
                return IValue.listFrom(tensors);
            }
            case TYPE_CODE_LIST:
                return IValue.listFrom(mapsToIValues((List<?>) data));
            case TYPE_CODE_DICT_STRING_KEY: {
                Map<String, IValue> dict = new HashMap<>();
                for (Map.Entry<?, ?> entry : ((Map<?, ?>) data).entrySet()) {
                    dict.put((String) entry.getKey(), mapToIValue(asEncodedMap(entry.getValue())));
                }
                return IValue.dictStringKeyFrom(dict);
            }
            case TYPE_CODE_DICT_LONG_KEY: {
                // Keys may arrive as Integer or Long (see TYPE_CODE_LONG), so normalize via Number.
                Map<Long, IValue> dict = new HashMap<>();
                for (Map.Entry<?, ?> entry : ((Map<?, ?>) data).entrySet()) {
                    long key = ((Number) entry.getKey()).longValue();
                    dict.put(key, mapToIValue(asEncodedMap(entry.getValue())));
                }
                return IValue.dictLongKeyFrom(dict);
            }
            default:
                throw new IllegalArgumentException("Map to IValue error.");
        }
    }

    /**
     * Decodes a channel map {@code {shape, memoryFormat, dtype, data}} into a {@link Tensor}.
     * {@code data} must be the primitive array matching the dtype ({@code byte[]} for
     * UINT8/INT8, {@code int[]}, {@code float[]}, {@code long[]} or {@code double[]}).
     *
     * @throws IllegalArgumentException for an unknown dtype/memory format or an invalid shape
     */
    Tensor mapToTensor(HashMap<String, Object> hashMap) {
        long[] shape = toLongArray(hashMap.get("shape"));
        MemoryFormat memoryFormat = parseMemoryFormat((Integer) hashMap.get("memoryFormat"));
        DType dtype = parseDType((Integer) hashMap.get("dtype"));
        Object data = hashMap.get("data");
        switch (dtype) {
            case UINT8:
                return Tensor.fromBlobUnsigned((byte[]) data, shape, memoryFormat);
            case INT8:
                return Tensor.fromBlob((byte[]) data, shape, memoryFormat);
            case INT32:
                return Tensor.fromBlob((int[]) data, shape, memoryFormat);
            case FLOAT32:
                return Tensor.fromBlob((float[]) data, shape, memoryFormat);
            case INT64:
                return Tensor.fromBlob((long[]) data, shape, memoryFormat);
            case FLOAT64:
                return Tensor.fromBlob((double[]) data, shape, memoryFormat);
            default:
                throw new IllegalArgumentException("Map to Tensor error.");
        }
    }

    /**
     * Encodes a {@link Tensor} as a channel map {@code {shape, memoryFormat, dtype, data}}.
     * {@code dtype} and {@code memoryFormat} are the JNI codes. {@code data} is omitted for
     * dtypes outside the six handled below.
     */
    HashMap<String, Object> tensorToMap(Tensor tensor) {
        HashMap<String, Object> encoded = new HashMap<>();
        encoded.put("shape", tensor.shape());
        encoded.put("memoryFormat", getJniCode(tensor.memoryFormat()));
        encoded.put("dtype", getJniCode(tensor.dtype()));
        switch (tensor.dtype()) {
            case UINT8:
                encoded.put("data", tensor.getDataAsUnsignedByteArray());
                break;
            case INT8:
                encoded.put("data", tensor.getDataAsByteArray());
                break;
            case INT32:
                encoded.put("data", tensor.getDataAsIntArray());
                break;
            case FLOAT32:
                encoded.put("data", tensor.getDataAsFloatArray());
                break;
            case INT64:
                encoded.put("data", tensor.getDataAsLongArray());
                break;
            case FLOAT64:
                encoded.put("data", tensor.getDataAsDoubleArray());
                break;
            default:
                break;
        }
        return encoded;
    }

    /** Decodes every element of a channel list with {@link #mapToIValue}. */
    private IValue[] mapsToIValues(List<?> encodedValues) {
        IValue[] values = new IValue[encodedValues.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = mapToIValue(asEncodedMap(encodedValues.get(i)));
        }
        return values;
    }

    /** Encodes every element of an IValue array with {@link #iValueToMap}. */
    private ArrayList<HashMap<String, Object>> iValuesToMaps(IValue[] values) {
        ArrayList<HashMap<String, Object>> encoded = new ArrayList<>(values.length);
        for (IValue value : values) {
            encoded.add(iValueToMap(value));
        }
        return encoded;
    }

    /** Casts a decoded channel value to the map shape used for encoded IValues and tensors. */
    @SuppressWarnings("unchecked")
    private static HashMap<String, Object> asEncodedMap(Object value) {
        return (HashMap<String, Object>) value;
    }

    /**
     * Converts a decoded channel value to {@code long[]}. Accepts {@code long[]}, {@code int[]} or
     * a {@link List} of numbers.
     *
     * @throws IllegalArgumentException for null or any other representation
     */
    private static long[] toLongArray(Object value) {
        if (value instanceof long[]) {
            return (long[]) value;
        }
        if (value instanceof int[]) {
            int[] ints = (int[]) value;
            long[] longs = new long[ints.length];
            for (int i = 0; i < ints.length; i++) {
                longs[i] = ints[i];
            }
            return longs;
        }
        if (value instanceof List) {
            List<?> items = (List<?>) value;
            long[] longs = new long[items.size()];
            for (int i = 0; i < longs.length; i++) {
                longs[i] = ((Number) items.get(i)).longValue();
            }
            return longs;
        }
        throw new IllegalArgumentException("Expected an integer list but got: " + value);
    }

    /**
     * Converts a decoded channel value to {@code double[]}. Accepts {@code double[]},
     * {@code float[]} or a {@link List} of numbers.
     *
     * @throws IllegalArgumentException for null or any other representation
     */
    private static double[] toDoubleArray(Object value) {
        if (value instanceof double[]) {
            return (double[]) value;
        }
        if (value instanceof float[]) {
            float[] floats = (float[]) value;
            double[] doubles = new double[floats.length];
            for (int i = 0; i < floats.length; i++) {
                doubles[i] = floats[i];
            }
            return doubles;
        }
        if (value instanceof List) {
            List<?> items = (List<?>) value;
            double[] doubles = new double[items.size()];
            for (int i = 0; i < doubles.length; i++) {
                doubles[i] = ((Number) items.get(i)).doubleValue();
            }
            return doubles;
        }
        throw new IllegalArgumentException("Expected a double list but got: " + value);
    }

    @Override
    public void onAttachedToEngine(FlutterPlugin.FlutterPluginBinding flutterPluginBinding) {
        BinaryMessenger messenger = flutterPluginBinding.getBinaryMessenger();
        // Calls are handled on a background task queue so model inference does not block the
        // platform thread.
        channel = new MethodChannel(messenger, "flutter_pytorch_lite", StandardMethodCodec.INSTANCE,
                messenger.makeBackgroundTaskQueue());
        channel.setMethodCallHandler(this);
    }

    @Override
    public void onDetachedFromEngine(FlutterPlugin.FlutterPluginBinding flutterPluginBinding) {
        if (channel != null) {
            channel.setMethodCallHandler(null);
            channel = null;
        }
        // Destroy each module independently, so that one failure neither skips the rest nor
        // leaves stale entries. ConcurrentHashMap iterators are weakly consistent, so removing
        // during iteration is safe.
        for (Integer moduleId : modules.keySet()) {
            Module module = modules.remove(moduleId);
            if (module != null) {
                try {
                    module.destroy();
                } catch (RuntimeException ignored) {
                    // Best-effort cleanup while detaching; there is no caller to report to.
                }
            }
        }
    }

    @Override
    public void onMethodCall(MethodCall methodCall, MethodChannel.Result result) {
        switch (methodCall.method) {
            case "forward":
                handleForward(methodCall, result);
                break;
            case "load":
                handleLoad(methodCall, result);
                break;
            case "destroy":
                handleDestroy(methodCall, result);
                break;
            default:
                result.notImplemented();
                break;
        }
    }

    /** Runs {@code forward} on a loaded module and replies with the encoded output. */
    private void handleForward(MethodCall methodCall, MethodChannel.Result result) {
        HashMap<String, Object> output;
        try {
            Integer moduleId = methodCall.argument("moduleId");
            Module module = moduleId == null ? null : modules.get(moduleId);
            if (module == null) {
                throw new IllegalArgumentException("No loaded module with id " + moduleId + ".");
            }
            List<?> inputs = methodCall.argument("inputs");
            if (inputs == null) {
                throw new IllegalArgumentException("Missing forward inputs.");
            }
            output = iValueToMap(module.forward(mapsToIValues(inputs)));
        } catch (Exception e) {
            result.error("forwardError", "Pytorch lite module forward error.", e);
            return;
        }
        // Reply outside the try block, so that a failure while encoding the reply cannot
        // trigger a second (error) reply.
        result.success(output);
    }

    /** Loads a Lite module from {@code filePath} and replies with its new id. */
    private void handleLoad(MethodCall methodCall, MethodChannel.Result result) {
        Module module;
        try {
            module = LiteModuleLoader.load(methodCall.<String>argument("filePath"));
        } catch (Exception e) {
            result.error("loadError", "Pytorch lite load module error", e);
            return;
        }
        int moduleId = nextModuleId.getAndIncrement();
        modules.put(moduleId, module);
        result.success(moduleId);
    }

    /**
     * Destroys and forgets a module. Unknown or missing ids are a no-op, so the call is
     * idempotent.
     */
    private void handleDestroy(MethodCall methodCall, MethodChannel.Result result) {
        Integer moduleId = methodCall.argument("moduleId");
        Module module = moduleId == null ? null : modules.remove(moduleId);
        if (module != null) {
            module.destroy();
        }
        result.success(null);
    }
}
