/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.operators.layer.recurrent.lstm;

import io.kinference.data.tensors.Tensor;
import io.kinference.data.tensors.TensorExtensionsKt;
import io.kinference.ndarray.Strides;
import io.kinference.ndarray.arrays.MutableNDArray;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import io.kinference.ndarray.extensions.NDArrayExtensionsKt;
import io.kinference.ndarray.extensions.SplitKt;
import io.kinference.operators.layer.recurrent.lstm.GatesData;
import io.kinference.operators.layer.recurrent.lstm.LSTMBase;
import io.kinference.operators.layer.recurrent.lstm.LSTMData;
import io.kinference.operators.layer.recurrent.lstm.LSTMLayer;
import io.kinference.primitives.types.DataType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0010 \n\u0002\u0010\u000e\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0015\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0007\u0018\u00002\u00020\u0001B#\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\u0006\u00a2\u0006\u0002\u0010\bJ4\u0010\u0012\u001a\b\u0012\u0004\u0012\u00020\u00130\u00052\f\u0010\u0014\u001a\b\u0012\u0004\u0012\u00020\u00150\u00052\u0006\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0006\u0010\u001a\u001a\u00020\u0003H\u0016J\u0018\u0010\u001b\u001a\u00020\u00192\u0006\u0010\u001c\u001a\u00020\u00132\u0006\u0010\u001d\u001a\u00020\u0013H\u0002J@\u0010\u001e\u001a\u00020\u001f2\u0006\u0010 \u001a\u00020\u00132\u0006\u0010!\u001a\u00020\u00132\b\u0010\"\u001a\u0004\u0018\u00010\u00132\b\u0010#\u001a\u0004\u0018\u00010\u00132\b\u0010$\u001a\u0004\u0018\u00010\u00132\b\u0010%\u001a\u0004\u0018\u00010\u0013H\u0014R\u001c\u0010\t\u001a\u0004\u0018\u00010\nX\u0086\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u000b\u0010\f\"\u0004\b\r\u0010\u000eR\u001c\u0010\u000f\u001a\u0004\u0018\u00010\nX\u0086\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u0010\u0010\f\"\u0004\b\u0011\u0010\u000e\u00a8\u0006&"}, d2={"Lio/kinference/operators/layer/recurrent/lstm/BiLSTMLayer;", "Lio/kinference/operators/layer/recurrent/lstm/LSTMBase;", "hiddenSize", "", "activations", "", "", "direction", "(ILjava/util/List;Ljava/lang/String;)V", "forwardLstmData", "Lio/kinference/operators/layer/recurrent/lstm/LSTMData;", "getForwardLstmData", "()Lio/kinference/operators/layer/recurrent/lstm/LSTMData;", "setForwardLstmData", "(Lio/kinference/operators/layer/recurrent/lstm/LSTMData;)V", "reverseLstmData", 
"getReverseLstmData", "setReverseLstmData", "apply", "Lio/kinference/data/tensors/Tensor;", "inputs", "Lio/kinference/ndarray/arrays/NDArray;", "sequenceLens", "", "outputArray", "Lio/kinference/ndarray/arrays/MutableNDArray;", "startOffset", "concatLasts", "forward", "reverse", "parseTempInputs", "", "weights", "recurrentWeights", "bias", "initialOutput", "initialCellState", "peepholes", "inference"})
public final class BiLSTMLayer
extends LSTMBase {
    @Nullable
    private LSTMData forwardLstmData;
    @Nullable
    private LSTMData reverseLstmData;

    @Nullable
    public final LSTMData getForwardLstmData() {
        return this.forwardLstmData;
    }

    public final void setForwardLstmData(@Nullable LSTMData lSTMData) {
        this.forwardLstmData = lSTMData;
    }

    @Nullable
    public final LSTMData getReverseLstmData() {
        return this.reverseLstmData;
    }

    public final void setReverseLstmData(@Nullable LSTMData lSTMData) {
        this.reverseLstmData = lSTMData;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    @NotNull
    public List<Tensor> apply(@NotNull List<? extends NDArray> inputs, @NotNull int[] sequenceLens, @NotNull MutableNDArray outputArray, int startOffset) {
        void reverseLastOutput;
        void forwardLastOutput;
        void output;
        List<Tensor> list;
        Object object;
        Intrinsics.checkParameterIsNotNull(inputs, (String)"inputs");
        Intrinsics.checkParameterIsNotNull((Object)sequenceLens, (String)"sequenceLens");
        Intrinsics.checkParameterIsNotNull((Object)outputArray, (String)"outputArray");
        int n = this.getHiddenSize();
        List<String> list2 = this.getActivations().subList(0, 3);
        LSTMData lSTMData = this.forwardLstmData;
        if (lSTMData == null) {
            Intrinsics.throwNpe();
        }
        Integer n2 = this.getSeqLength();
        if (n2 == null) {
            Intrinsics.throwNpe();
        }
        int n3 = n2;
        Integer n4 = this.getBatchSize();
        if (n4 == null) {
            Intrinsics.throwNpe();
        }
        int n5 = n4;
        DataType dataType = this.getType();
        if (dataType == null) {
            Intrinsics.throwNpe();
        }
        LSTMLayer forwardLayer = LSTMLayer.Companion.create(n, list2, "forward", lSTMData, n3, n5, TensorExtensionsKt.resolveProtoDataType(dataType));
        int n6 = this.getHiddenSize();
        List<String> list3 = this.getActivations().subList(3, 6);
        LSTMData lSTMData2 = this.reverseLstmData;
        if (lSTMData2 == null) {
            Intrinsics.throwNpe();
        }
        Integer n7 = this.getSeqLength();
        if (n7 == null) {
            Intrinsics.throwNpe();
        }
        int n8 = n7;
        Integer n9 = this.getBatchSize();
        if (n9 == null) {
            Intrinsics.throwNpe();
        }
        int n10 = n9;
        DataType dataType2 = this.getType();
        if (dataType2 == null) {
            Intrinsics.throwNpe();
        }
        LSTMLayer reverseLayer = LSTMLayer.Companion.create(n6, list3, "reverse", lSTMData2, n8, n10, TensorExtensionsKt.resolveProtoDataType(dataType2));
        Object object2 = object = forwardLayer.apply(inputs, sequenceLens, outputArray, startOffset);
        boolean bl = false;
        Tensor tensor = object2.get(1);
        object2 = object;
        bl = false;
        Tensor forwardLastCellState = object2.get(2);
        Integer n11 = this.getBatchSize();
        if (n11 == null) {
            Intrinsics.throwNpe();
        }
        List<Tensor> list4 = list = reverseLayer.apply(inputs, sequenceLens, outputArray, startOffset + n11 * this.getHiddenSize());
        boolean bl2 = false;
        object = list4.get(0);
        list4 = list;
        bl2 = false;
        object2 = list4.get(1);
        list4 = list;
        bl2 = false;
        Tensor reverseLastCellState = list4.get(2);
        return CollectionsKt.listOf((Object[])new Tensor[]{output, TensorExtensionsKt.asTensor$default((NDArray)this.concatLasts((Tensor)forwardLastOutput, (Tensor)reverseLastOutput), null, 1, null), TensorExtensionsKt.asTensor$default((NDArray)this.concatLasts(forwardLastCellState, reverseLastCellState), null, 1, null)});
    }

    private final MutableNDArray concatLasts(Tensor forward, Tensor reverse2) {
        int[] nArray = forward.getData().getShape();
        boolean bl = false;
        int[] nArray2 = Arrays.copyOf(nArray, nArray.length);
        Intrinsics.checkExpressionValueIsNotNull((Object)nArray2, (String)"java.util.Arrays.copyOf(this, size)");
        int[] newShape = nArray2;
        newShape[0] = 2;
        Strides newStrides = new Strides(newShape);
        DataType dataType = this.getType();
        if (dataType == null) {
            Intrinsics.throwNpe();
        }
        MutableNDArray newArray = ArrayFactoriesKt.allocateNDArray((DataType)dataType, (Strides)newStrides);
        MutableNDArray.DefaultImpls.copyFrom$default((MutableNDArray)newArray, (int)0, (NDArray)forward.getData(), (int)0, (int)0, (int)12, null);
        MutableNDArray.DefaultImpls.copyFrom$default((MutableNDArray)newArray, (int)forward.getData().getLinearSize(), (NDArray)reverse2.getData(), (int)0, (int)0, (int)12, null);
        return newArray;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    protected void parseTempInputs(@NotNull Tensor weights, @NotNull Tensor recurrentWeights, @Nullable Tensor bias, @Nullable Tensor initialOutput, @Nullable Tensor initialCellState, @Nullable Tensor peepholes) {
        MutableNDArray forward;
        MutableNDArray reverse2;
        List list;
        MutableNDArray forwardParsedWeights;
        Object object;
        Intrinsics.checkParameterIsNotNull((Object)weights, (String)"weights");
        Intrinsics.checkParameterIsNotNull((Object)recurrentWeights, (String)"recurrentWeights");
        if (this.forwardLstmData == null || this.reverseLstmData == null) {
            void forwardParsedRecWeights;
            void $this$mapTo$iv$iv;
            GatesData gatesData;
            Collection collection;
            void $this$mapTo$iv$iv2;
            Iterable $this$map$iv = SplitKt.splitWithAxis$default((NDArray)weights.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            boolean $i$f$map = false;
            Iterable iterable = $this$map$iv;
            Collection destination$iv$iv = new ArrayList(CollectionsKt.collectionSizeOrDefault((Iterable)$this$map$iv, (int)10));
            boolean $i$f$mapTo22 = false;
            for (Object item$iv$iv : $this$mapTo$iv$iv2) {
                void it;
                MutableNDArray mutableNDArray = (MutableNDArray)item$iv$iv;
                collection = destination$iv$iv;
                boolean bl = false;
                gatesData = GatesData.Companion.createWeights((MutableNDArray)it);
                collection.add(gatesData);
            }
            object = (List)destination$iv$iv;
            $this$map$iv = object;
            boolean bl = false;
            GatesData gatesData2 = (GatesData)$this$map$iv.get(0);
            $this$map$iv = object;
            bl = false;
            GatesData reverseParsedWeights = (GatesData)$this$map$iv.get(1);
            Iterable $this$map$iv2 = SplitKt.splitWithAxis$default((NDArray)recurrentWeights.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            boolean $i$f$map2 = false;
            Iterable $i$f$mapTo22 = $this$map$iv2;
            Collection destination$iv$iv2 = new ArrayList(CollectionsKt.collectionSizeOrDefault((Iterable)$this$map$iv2, (int)10));
            boolean $i$f$mapTo = false;
            for (Object item$iv$iv : $this$mapTo$iv$iv) {
                void it;
                MutableNDArray mutableNDArray = (MutableNDArray)item$iv$iv;
                collection = destination$iv$iv2;
                boolean bl2 = false;
                gatesData = GatesData.Companion.createWeights((MutableNDArray)it);
                collection.add(gatesData);
            }
            List list2 = (List)destination$iv$iv2;
            iterable = list2;
            boolean bl3 = false;
            object = (GatesData)iterable.get(0);
            iterable = list2;
            bl3 = false;
            GatesData reverseParsedRecWeights = (GatesData)iterable.get(1);
            DataType dataType = this.getType();
            if (dataType == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = new LSTMData(dataType, (GatesData)forwardParsedWeights, (GatesData)forwardParsedRecWeights, null, null, null, null, 120, null);
            DataType dataType2 = this.getType();
            if (dataType2 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = new LSTMData(dataType2, reverseParsedWeights, reverseParsedRecWeights, null, null, null, null, 120, null);
            this.setWeights(weights.getData());
            this.setRecurrentWeights(recurrentWeights.getData());
            this.setBias(null);
            this.setInitialOutput(null);
            this.setInitialCellState(null);
            this.setPeepholes(null);
        }
        if (weights.getData() != this.getWeights()) {
            object = SplitKt.splitWithAxis$default((NDArray)weights.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            list = object;
            boolean bl = false;
            forwardParsedWeights = (MutableNDArray)list.get(0);
            list = object;
            bl = false;
            reverse2 = (MutableNDArray)list.get(1);
            LSTMData lSTMData = this.forwardLstmData;
            if (lSTMData == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = lSTMData.updateWeights(GatesData.Companion.createWeights(forward));
            LSTMData lSTMData2 = this.reverseLstmData;
            if (lSTMData2 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = lSTMData2.updateWeights(GatesData.Companion.createWeights(reverse2));
            this.setWeights(weights.getData());
        }
        if (recurrentWeights.getData() != this.getRecurrentWeights()) {
            object = SplitKt.splitWithAxis$default((NDArray)recurrentWeights.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            list = object;
            boolean bl = false;
            forward = (MutableNDArray)list.get(0);
            list = object;
            bl = false;
            reverse2 = (MutableNDArray)list.get(1);
            LSTMData lSTMData = this.forwardLstmData;
            if (lSTMData == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = lSTMData.updateRecurrentWeights(GatesData.Companion.createWeights(forward));
            LSTMData lSTMData3 = this.reverseLstmData;
            if (lSTMData3 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = lSTMData3.updateRecurrentWeights(GatesData.Companion.createWeights(reverse2));
            this.setRecurrentWeights(recurrentWeights.getData());
        }
        if (bias != null && bias.getData() != this.getBias()) {
            object = SplitKt.splitWithAxis$default((NDArray)bias.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            list = object;
            boolean bl = false;
            forward = (MutableNDArray)list.get(0);
            list = object;
            bl = false;
            reverse2 = (MutableNDArray)list.get(1);
            LSTMData lSTMData = this.forwardLstmData;
            if (lSTMData == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = lSTMData.updateBias(GatesData.Companion.createBias(forward));
            LSTMData lSTMData4 = this.reverseLstmData;
            if (lSTMData4 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = lSTMData4.updateBias(GatesData.Companion.createBias(reverse2));
            this.setBias(bias.getData());
        }
        if (initialOutput != null && initialOutput.getData() != this.getInitialOutput()) {
            object = SplitKt.splitWithAxis$default((NDArray)initialOutput.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            list = object;
            boolean bl = false;
            forward = (MutableNDArray)list.get(0);
            list = object;
            bl = false;
            reverse2 = (MutableNDArray)list.get(1);
            LSTMData lSTMData = this.forwardLstmData;
            if (lSTMData == null) {
                Intrinsics.throwNpe();
            }
            NDArray nDArray = (NDArray)NDArrayExtensionsKt.squeeze((MutableNDArray)forward, (int[])new int[]{0});
            Integer n = this.getBatchSize();
            if (n == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = lSTMData.updateInitialOutput(SplitKt.splitWithAxis$default((NDArray)nDArray, (int)n, (int)0, (boolean)false, (int)6, null));
            LSTMData lSTMData5 = this.reverseLstmData;
            if (lSTMData5 == null) {
                Intrinsics.throwNpe();
            }
            NDArray nDArray2 = (NDArray)NDArrayExtensionsKt.squeeze((MutableNDArray)reverse2, (int[])new int[]{0});
            Integer n2 = this.getBatchSize();
            if (n2 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = lSTMData5.updateInitialOutput(SplitKt.splitWithAxis$default((NDArray)nDArray2, (int)n2, (int)0, (boolean)false, (int)6, null));
            this.setInitialOutput(initialOutput.getData());
        }
        if (initialCellState != null && initialCellState.getData() != this.getInitialCellState()) {
            object = SplitKt.splitWithAxis$default((NDArray)initialCellState.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            list = object;
            boolean bl = false;
            forward = (MutableNDArray)list.get(0);
            list = object;
            bl = false;
            reverse2 = (MutableNDArray)list.get(1);
            LSTMData lSTMData = this.forwardLstmData;
            if (lSTMData == null) {
                Intrinsics.throwNpe();
            }
            NDArray nDArray = (NDArray)NDArrayExtensionsKt.squeeze((MutableNDArray)forward, (int[])new int[]{0});
            Integer n = this.getBatchSize();
            if (n == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = lSTMData.updateInitialCellGate(SplitKt.splitWithAxis$default((NDArray)nDArray, (int)n, (int)0, (boolean)false, (int)6, null));
            LSTMData lSTMData6 = this.reverseLstmData;
            if (lSTMData6 == null) {
                Intrinsics.throwNpe();
            }
            NDArray nDArray3 = (NDArray)NDArrayExtensionsKt.squeeze((MutableNDArray)reverse2, (int[])new int[]{0});
            Integer n3 = this.getBatchSize();
            if (n3 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = lSTMData6.updateInitialCellGate(SplitKt.splitWithAxis$default((NDArray)nDArray3, (int)n3, (int)0, (boolean)false, (int)6, null));
            this.setInitialCellState(initialCellState.getData());
        }
        if (peepholes != null && peepholes.getData() != this.getPeepholes()) {
            object = SplitKt.splitWithAxis$default((NDArray)peepholes.getData(), (int)2, (int)0, (boolean)false, (int)6, null);
            list = object;
            boolean bl = false;
            forward = (MutableNDArray)list.get(0);
            list = object;
            bl = false;
            reverse2 = (MutableNDArray)list.get(1);
            LSTMData lSTMData = this.forwardLstmData;
            if (lSTMData == null) {
                Intrinsics.throwNpe();
            }
            this.forwardLstmData = lSTMData.updatePeepholes(GatesData.Companion.createPeepholes(forward));
            LSTMData lSTMData7 = this.reverseLstmData;
            if (lSTMData7 == null) {
                Intrinsics.throwNpe();
            }
            this.reverseLstmData = lSTMData7.updatePeepholes(GatesData.Companion.createPeepholes(reverse2));
            this.setPeepholes(peepholes.getData());
        }
    }

    public BiLSTMLayer(int hiddenSize, @NotNull List<String> activations, @NotNull String direction) {
        Intrinsics.checkParameterIsNotNull(activations, (String)"activations");
        Intrinsics.checkParameterIsNotNull((Object)direction, (String)"direction");
        super(hiddenSize, activations, direction);
        boolean bl = Intrinsics.areEqual((Object)direction, (Object)"bidirectional");
        boolean bl2 = false;
        boolean bl3 = false;
        bl3 = false;
        boolean bl4 = false;
        if (!bl) {
            boolean bl5 = false;
            String string = "Failed requirement.";
            throw (Throwable)new IllegalArgumentException(string.toString());
        }
        bl = activations.size() == 6;
        bl2 = false;
        bl3 = false;
        bl3 = false;
        bl4 = false;
        if (!bl) {
            boolean bl6 = false;
            String string = "Failed requirement.";
            throw (Throwable)new IllegalArgumentException(string.toString());
        }
    }
}

