/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.operators.layer.recurrent.lstm;

import io.kinference.data.tensors.Tensor;
import io.kinference.ndarray.Strides;
import io.kinference.ndarray.arrays.IntNDArray;
import io.kinference.ndarray.arrays.MutableNDArray;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.arrays.pointers.IntPointer;
import io.kinference.ndarray.arrays.tiled.IntTiledArray;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import io.kinference.ndarray.extensions.SplitKt;
import io.kinference.operators.layer.recurrent.RecurrentLayer;
import io.kinference.primitives.types.DataType;
import java.util.Collection;
import java.util.List;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000H\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0010 \n\u0002\u0010\u000e\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0014\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0015\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0000\b&\u0018\u00002\u00020\u0001B#\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\u0007\u001a\u00020\u0006\u00a2\u0006\u0002\u0010\bJ \u0010-\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010.0\u00052\u000e\u0010/\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010.0\u0005H\u0016J4\u0010-\u001a\b\u0012\u0004\u0012\u00020.0\u00052\f\u00100\u001a\b\u0012\u0004\u0012\u00020\u00100\u00052\u0006\u00101\u001a\u0002022\u0006\u00103\u001a\u0002042\u0006\u00105\u001a\u00020\u0003H&J\u0016\u00106\u001a\b\u0012\u0004\u0012\u0002040\u00052\u0006\u00107\u001a\u00020.H\u0002J\u0012\u00108\u001a\u0002022\b\u00107\u001a\u0004\u0018\u00010.H\u0002J@\u00109\u001a\u00020:2\u0006\u0010*\u001a\u00020.2\u0006\u0010\u001e\u001a\u00020.2\b\u0010\u000f\u001a\u0004\u0018\u00010.2\b\u0010\u0018\u001a\u0004\u0018\u00010.2\b\u0010\u0015\u001a\u0004\u0018\u00010.2\b\u0010\u001b\u001a\u0004\u0018\u00010.H$R\u001e\u0010\t\u001a\u0004\u0018\u00010\u0003X\u0084\u000e\u00a2\u0006\u0010\n\u0002\u0010\u000e\u001a\u0004\b\n\u0010\u000b\"\u0004\b\f\u0010\rR\u001c\u0010\u000f\u001a\u0004\u0018\u00010\u0010X\u0084\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u0011\u0010\u0012\"\u0004\b\u0013\u0010\u0014R\u001c\u0010\u0015\u001a\u0004\u0018\u00010\u0010X\u0084\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u0016\u0010\u0012\"\u0004\b\u0017\u0010\u0014R\u001c\u0010\u0018\u001a\u0004\u0018\u00010\u0010X\u0084\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u0019\u0010\u0012\"\u0004\b\u001a\u0010\u0014R\u001c\u0010\u001b\u001a\u0004\u0018\u00010\u0010X\u0084\u000e\u00a2\u0006\u000e\n\
u0000\u001a\u0004\b\u001c\u0010\u0012\"\u0004\b\u001d\u0010\u0014R\u001c\u0010\u001e\u001a\u0004\u0018\u00010\u0010X\u0084\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u001f\u0010\u0012\"\u0004\b \u0010\u0014R\u001e\u0010!\u001a\u0004\u0018\u00010\u0003X\u0084\u000e\u00a2\u0006\u0010\n\u0002\u0010\u000e\u001a\u0004\b\"\u0010\u000b\"\u0004\b#\u0010\rR\u001c\u0010$\u001a\u0004\u0018\u00010%X\u0084\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b&\u0010'\"\u0004\b(\u0010)R\u001c\u0010*\u001a\u0004\u0018\u00010\u0010X\u0084\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b+\u0010\u0012\"\u0004\b,\u0010\u0014\u00a8\u0006;"}, d2={"Lio/kinference/operators/layer/recurrent/lstm/LSTMBase;", "Lio/kinference/operators/layer/recurrent/RecurrentLayer;", "hiddenSize", "", "activations", "", "", "direction", "(ILjava/util/List;Ljava/lang/String;)V", "batchSize", "getBatchSize", "()Ljava/lang/Integer;", "setBatchSize", "(Ljava/lang/Integer;)V", "Ljava/lang/Integer;", "bias", "Lio/kinference/ndarray/arrays/NDArray;", "getBias", "()Lio/kinference/ndarray/arrays/NDArray;", "setBias", "(Lio/kinference/ndarray/arrays/NDArray;)V", "initialCellState", "getInitialCellState", "setInitialCellState", "initialOutput", "getInitialOutput", "setInitialOutput", "peepholes", "getPeepholes", "setPeepholes", "recurrentWeights", "getRecurrentWeights", "setRecurrentWeights", "seqLength", "getSeqLength", "setSeqLength", "type", "Lio/kinference/primitives/types/DataType;", "getType", "()Lio/kinference/primitives/types/DataType;", "setType", "(Lio/kinference/primitives/types/DataType;)V", "weights", "getWeights", "setWeights", "apply", "Lio/kinference/data/tensors/Tensor;", "inputList", "inputs", "sequenceLens", "", "outputArray", "Lio/kinference/ndarray/arrays/MutableNDArray;", "startOffset", "parseInput", "input", "parseSequenceLength", "parseTempInputs", "", "inference"})
public abstract class LSTMBase
extends RecurrentLayer {
    /*
     * Positions of the operator inputs inside the list passed to apply(List).
     * Only the first three are required; the rest may be absent (list too short) or null.
     */
    private static final int INPUT_INDEX = 0;
    private static final int WEIGHTS_INDEX = 1;
    private static final int RECURRENT_WEIGHTS_INDEX = 2;
    private static final int BIAS_INDEX = 3;
    private static final int SEQUENCE_LENS_INDEX = 4;
    private static final int INITIAL_OUTPUT_INDEX = 5;
    private static final int INITIAL_CELL_STATE_INDEX = 6;
    private static final int PEEPHOLES_INDEX = 7;

    /** Parameter arrays captured by subclasses in {@link #parseTempInputs}; never written by this class. */
    @Nullable
    private NDArray weights;
    @Nullable
    private NDArray recurrentWeights;
    @Nullable
    private NDArray bias;
    @Nullable
    private NDArray peepholes;
    @Nullable
    private NDArray initialOutput;
    @Nullable
    private NDArray initialCellState;
    /** Dimension 0 of the most recent input sequence tensor. */
    @Nullable
    private Integer seqLength;
    /** Dimension 1 of the most recent input sequence tensor. */
    @Nullable
    private Integer batchSize;
    /** Element type of the most recent input sequence tensor; used to allocate the output. */
    @Nullable
    private DataType type;

    @Nullable
    protected final NDArray getWeights() {
        return this.weights;
    }

    protected final void setWeights(@Nullable NDArray nDArray) {
        this.weights = nDArray;
    }

    @Nullable
    protected final NDArray getRecurrentWeights() {
        return this.recurrentWeights;
    }

    protected final void setRecurrentWeights(@Nullable NDArray nDArray) {
        this.recurrentWeights = nDArray;
    }

    @Nullable
    protected final NDArray getBias() {
        return this.bias;
    }

    protected final void setBias(@Nullable NDArray nDArray) {
        this.bias = nDArray;
    }

    @Nullable
    protected final NDArray getPeepholes() {
        return this.peepholes;
    }

    protected final void setPeepholes(@Nullable NDArray nDArray) {
        this.peepholes = nDArray;
    }

    @Nullable
    protected final NDArray getInitialOutput() {
        return this.initialOutput;
    }

    protected final void setInitialOutput(@Nullable NDArray nDArray) {
        this.initialOutput = nDArray;
    }

    @Nullable
    protected final NDArray getInitialCellState() {
        return this.initialCellState;
    }

    protected final void setInitialCellState(@Nullable NDArray nDArray) {
        this.initialCellState = nDArray;
    }

    @Nullable
    protected final Integer getSeqLength() {
        return this.seqLength;
    }

    protected final void setSeqLength(@Nullable Integer n) {
        this.seqLength = n;
    }

    @Nullable
    protected final Integer getBatchSize() {
        return this.batchSize;
    }

    protected final void setBatchSize(@Nullable Integer n) {
        this.batchSize = n;
    }

    @Nullable
    protected final DataType getType() {
        return this.type;
    }

    protected final void setType(@Nullable DataType dataType) {
        this.type = dataType;
    }

    /**
     * Runs the recurrence and produces the layer outputs.
     *
     * @param inputs       the input sequence split into {@code seqLength * batchSize} parts by {@link #parseInput}
     * @param sequenceLens flattened contents of the sequence-lengths input, or {@code batchSize} copies of
     *                     {@code seqLength} when that input is absent
     * @param outputArray  preallocated output with shape {@code {seqLength, numDirections, batchSize, hiddenSize}}
     * @param startOffset  offset into {@code outputArray}; {@link #apply(List)} always passes 0
     * @return the output tensors of the layer
     */
    @NotNull
    public abstract List<Tensor> apply(@NotNull List<? extends NDArray> inputs, @NotNull int[] sequenceLens, @NotNull MutableNDArray outputArray, int startOffset);

    /**
     * Validates the operator inputs, records the input geometry, allocates the output and delegates to
     * {@link #apply(List, int[], MutableNDArray, int)}.
     *
     * <p>Inputs 0-2 are required. Inputs 3 and 5-7 are optional and may be missing or {@code null}; every
     * supplied one must share the element type of input 0. Input 4 (sequence lengths), when supplied, must
     * be {@link DataType#INT}.
     *
     * @param inputList the operator inputs
     * @return the output tensors produced by the concrete layer
     * @throws IllegalArgumentException with message "Failed requirement." when an input has the wrong type
     * @throws NullPointerException     (Kotlin flavour) when a required input is null
     */
    @Override
    @NotNull
    public List<Tensor> apply(@NotNull List<Tensor> inputList) {
        Intrinsics.checkParameterIsNotNull(inputList, "inputList");
        Tensor input = nonNull(inputList.get(INPUT_INDEX));
        NDArray inputData = input.getData();
        DataType inputType = inputData.getType();
        requireMatchingTypes(inputList, inputType);

        int[] inputShape = inputData.getShape();
        this.seqLength = inputShape[0];
        this.batchSize = inputShape[1];
        this.type = inputType;

        Tensor weights = nonNull(inputList.get(WEIGHTS_INDEX));
        Tensor recurrentWeights = nonNull(inputList.get(RECURRENT_WEIGHTS_INDEX));
        Tensor bias = optionalInput(inputList, BIAS_INDEX);
        Tensor sequenceLens = optionalInput(inputList, SEQUENCE_LENS_INDEX);
        if (sequenceLens != null && sequenceLens.getData().getType() != DataType.INT) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        Tensor initialOutput = optionalInput(inputList, INITIAL_OUTPUT_INDEX);
        Tensor initialCellState = optionalInput(inputList, INITIAL_CELL_STATE_INDEX);
        Tensor peepholes = optionalInput(inputList, PEEPHOLES_INDEX);
        this.parseTempInputs(weights, recurrentWeights, bias, initialOutput, initialCellState, peepholes);

        // Re-read the fields instead of reusing the locals above: parseTempInputs is implemented by
        // subclasses, which can change them through the protected setters.
        int numDirections = "bidirectional".equals(this.getDirection()) ? 2 : 1;
        int[] outputShape = new int[]{
            nonNull(this.seqLength), numDirections, nonNull(this.batchSize), this.getHiddenSize()
        };
        MutableNDArray outputArray = ArrayFactoriesKt.allocateNDArray(nonNull(this.type), new Strides(outputShape));
        return this.apply(this.parseInput(input), this.parseSequenceLength(sequenceLens), outputArray, 0);
    }

    /**
     * Checks that every supplied input except the sequence lengths has element type {@code expected}.
     * Absent ({@code null}) optional inputs are skipped instead of being treated as a type mismatch.
     */
    private static void requireMatchingTypes(List<Tensor> inputs, DataType expected) {
        for (int i = 0; i < inputs.size(); ++i) {
            if (i == SEQUENCE_LENS_INDEX) {
                continue;
            }
            Tensor tensor = inputs.get(i);
            if (tensor != null && tensor.getData().getType() != expected) {
                throw new IllegalArgumentException("Failed requirement.");
            }
        }
    }

    /** Returns the input at {@code index}, or {@code null} when the list is too short to contain it. */
    @Nullable
    private static Tensor optionalInput(List<Tensor> inputs, int index) {
        return index < inputs.size() ? inputs.get(index) : null;
    }

    /** Returns {@code value}, failing with Kotlin's NullPointerException (as the original {@code !!}) when it is null. */
    private static <T> T nonNull(@Nullable T value) {
        if (value == null) {
            Intrinsics.throwNpe();
        }
        return value;
    }

    /**
     * Splits the input sequence into {@code shape[0] * shape[1]} parts, each with strides {@code {1, shape[2]}}.
     */
    private final List<MutableNDArray> parseInput(Tensor input) {
        NDArray data = input.getData();
        int[] shape = data.getShape();
        return SplitKt.splitParts(data, shape[0] * shape[1], new Strides(new int[]{1, shape[2]}));
    }

    /**
     * Builds the per-batch sequence lengths.
     *
     * @param input the optional sequence-lengths tensor (must hold INT data when present)
     * @return {@code batchSize} copies of {@code seqLength} when {@code input} is null, otherwise all of the
     *         tensor's values in linear order
     */
    private final int[] parseSequenceLength(@Nullable Tensor input) {
        if (input == null) {
            int batch = nonNull(this.batchSize);
            int[] lengths = new int[batch];
            for (int i = 0; i < batch; ++i) {
                lengths[i] = nonNull(this.seqLength);
            }
            return lengths;
        }
        NDArray data = input.getData();
        // The Kotlin original called pointer() with its default start index; 0 is passed explicitly here
        // (the synthetic pointer$default bridge is not callable from Java). TODO confirm the default is 0.
        IntPointer pointer = ((IntNDArray) data).getArray().pointer(0);
        int count = data.getLinearSize();
        int[] lengths = new int[count];
        for (int i = 0; i < count; ++i) {
            lengths[i] = pointer.getAndIncrement();
        }
        return lengths;
    }

    /**
     * Lets the concrete layer capture its parameter tensors. Called once per {@link #apply(List)}, after the
     * sequence length, batch size and type fields have been set.
     */
    protected abstract void parseTempInputs(@NotNull Tensor weights, @NotNull Tensor recurrentWeights, @Nullable Tensor bias, @Nullable Tensor initialOutput, @Nullable Tensor initialCellState, @Nullable Tensor peepholes);

    /**
     * @param hiddenSize  forwarded to {@link RecurrentLayer}; used as the last output dimension
     * @param activations forwarded to {@link RecurrentLayer}
     * @param direction   forwarded to {@link RecurrentLayer}; {@code "bidirectional"} doubles the output's second dimension
     */
    public LSTMBase(int hiddenSize, @NotNull List<String> activations, @NotNull String direction) {
        // The null checks must run inside the argument list: Java requires super(...) to be the first statement.
        super(hiddenSize, checkParameter(activations, "activations"), checkParameter(direction, "direction"));
    }

    /** Applies Kotlin's parameter null check (same exception type as the original) and returns the value. */
    private static <T> T checkParameter(T value, String name) {
        Intrinsics.checkParameterIsNotNull(value, name);
        return value;
    }
}

