/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.operators.layer.attention;

import io.kinference.attributes.Attribute;
import io.kinference.data.tensors.Tensor;
import io.kinference.data.tensors.TensorExtensionsKt;
import io.kinference.graph.Context;
import io.kinference.ndarray.Strides;
import io.kinference.ndarray.arrays.FloatNDArray;
import io.kinference.ndarray.arrays.IntNDArray;
import io.kinference.ndarray.arrays.MutableFloatNDArray;
import io.kinference.ndarray.arrays.MutableNDArray;
import io.kinference.ndarray.arrays.MutableNumberNDArray;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.arrays.NumberNDArray;
import io.kinference.ndarray.arrays.pointers.FloatPointer;
import io.kinference.ndarray.arrays.pointers.IntPointer;
import io.kinference.ndarray.arrays.tiled.FloatTiledArray;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import io.kinference.onnx.AttributeProto;
import io.kinference.onnx.TensorProto;
import io.kinference.operators.AttributeInfo;
import io.kinference.operators.IOInfo;
import io.kinference.operators.Operator;
import io.kinference.operators.OperatorInfo;
import io.kinference.operators.activations.Softmax;
import io.kinference.operators.layer.attention.Attention;
import io.kinference.primitives.types.DataType;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.ResultKt;
import kotlin.TuplesKt;
import kotlin.TypeCastException;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.collections.SetsKt;
import kotlin.coroutines.Continuation;
import kotlin.coroutines.CoroutineContext;
import kotlin.coroutines.intrinsics.IntrinsicsKt;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.PropertyReference1;
import kotlin.jvm.internal.PropertyReference1Impl;
import kotlin.jvm.internal.Reflection;
import kotlin.reflect.KDeclarationContainer;
import kotlin.reflect.KProperty;
import kotlinx.coroutines.BuildersKt;
import kotlinx.coroutines.CoroutineScope;
import kotlinx.coroutines.Dispatchers;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000@\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010 \n\u0002\b\u0003\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0010\u000b\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018\u0000 \u001a2\u000e\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u00020\u0001:\u0001\u001aB;\u0012\u0018\u0010\u0003\u001a\u0014\u0012\u0004\u0012\u00020\u0005\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00070\u00060\u0004\u0012\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00050\t\u0012\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00050\t\u00a2\u0006\u0002\u0010\u000bJ(\u0010\u0017\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\u00020\t2\u0006\u0010\u0018\u001a\u00020\u00192\u000e\u0010\b\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\u00020\tH\u0016R\u001b\u0010\f\u001a\u00020\r8BX\u0082\u0084\u0002\u00a2\u0006\f\n\u0004\b\u0010\u0010\u0011\u001a\u0004\b\u000e\u0010\u000fR\u001b\u0010\u0012\u001a\u00020\u00138BX\u0082\u0084\u0002\u00a2\u0006\f\n\u0004\b\u0016\u0010\u0011\u001a\u0004\b\u0014\u0010\u0015\u00a8\u0006\u001b"}, d2={"Lio/kinference/operators/layer/attention/Attention;", "Lio/kinference/operators/Operator;", "Lio/kinference/data/tensors/Tensor;", "attributes", "", "", "Lio/kinference/attributes/Attribute;", "", "inputs", "", "outputs", "(Ljava/util/Map;Ljava/util/List;Ljava/util/List;)V", "numHeads", "", "getNumHeads", "()I", "numHeads$delegate", "Lio/kinference/operators/Operator$AttributeValueDelegate;", "unidir", "", "getUnidir", "()Z", "unidir$delegate", "apply", "context", "Lio/kinference/graph/Context;", "Companion", "inference"})
public final class Attention
extends Operator<Tensor, Tensor> {
    // Kotlin property references for the delegated `numHeads` and `unidir` properties;
    // index 0 is passed to numHeads$delegate.getValue, index 1 to unidir$delegate.getValue.
    static final /* synthetic */ KProperty[] $$delegatedProperties;
    private final Operator.AttributeValueDelegate numHeads$delegate;
    private final Operator.AttributeValueDelegate unidir$delegate;
    private static final Set<TensorProto.DataType> TYPE_CONSTRAINTS;
    private static final List<AttributeInfo> ATTRIBUTES_INFO;
    private static final List<IOInfo> INPUTS_INFO;
    private static final List<IOInfo> OUTPUTS_INFO;
    private static final OperatorInfo INFO;
    public static final Companion Companion;

    /*
     * Builds the static operator description: delegated-property references, the companion
     * singleton, the accepted tensor element types, and the attribute / input / output
     * descriptors that are combined into INFO.
     *
     * NOTE(review): the trailing "<int>, null" pairs in the AttributeInfo / IOInfo constructor
     * calls look like Kotlin default-argument bitmasks plus a DefaultConstructorMarker emitted by
     * the compiler — TODO confirm against the Kotlin source before editing them.
     */
    static {
        KProperty numHeadsProperty = (KProperty)Reflection.property1((PropertyReference1)new PropertyReference1Impl((KDeclarationContainer)Reflection.getOrCreateKotlinClass(Attention.class), "numHeads", "getNumHeads()I"));
        KProperty unidirProperty = (KProperty)Reflection.property1((PropertyReference1)new PropertyReference1Impl((KDeclarationContainer)Reflection.getOrCreateKotlinClass(Attention.class), "unidir", "getUnidir()Z"));
        $$delegatedProperties = new KProperty[]{numHeadsProperty, unidirProperty};

        Companion = new Companion(null);

        // Element types accepted for the floating-point inputs and outputs.
        TYPE_CONSTRAINTS = SetsKt.setOf((Object[])new TensorProto.DataType[]{TensorProto.DataType.FLOAT, TensorProto.DataType.FLOAT16});

        // Attributes: num_heads (flag = true) and unidirectional (flag = false, default value 0).
        AttributeInfo numHeadsInfo = new AttributeInfo("num_heads", SetsKt.setOf((Object)((Object)AttributeProto.AttributeType.INT)), true, null, 8, null);
        AttributeInfo unidirectionalInfo = new AttributeInfo("unidirectional", SetsKt.setOf((Object)((Object)AttributeProto.AttributeType.INT)), false, 0);
        ATTRIBUTES_INFO = CollectionsKt.listOf((Object[])new AttributeInfo[]{numHeadsInfo, unidirectionalInfo});

        // Inputs by position; mask_index is the only INT32 input.
        IOInfo inputInfo = new IOInfo(0, TYPE_CONSTRAINTS, "input", false, null, false, null, 0, 240, null);
        IOInfo weightInfo = new IOInfo(1, TYPE_CONSTRAINTS, "weight", false, null, false, null, 3, 112, null);
        IOInfo biasInfo = new IOInfo(2, TYPE_CONSTRAINTS, "bias", false, null, false, null, 3, 112, null);
        IOInfo maskIndexInfo = new IOInfo(3, SetsKt.setOf((Object)((Object)TensorProto.DataType.INT32)), "mask_index", true, null, false, null, 0, 240, null);
        IOInfo pastInfo = new IOInfo(4, TYPE_CONSTRAINTS, "past", true, null, false, null, 0, 240, null);
        INPUTS_INFO = CollectionsKt.listOf((Object[])new IOInfo[]{inputInfo, weightInfo, biasInfo, maskIndexInfo, pastInfo});

        // Outputs by position.
        IOInfo outputInfo = new IOInfo(0, TYPE_CONSTRAINTS, "output", false, null, false, null, 0, 240, null);
        IOInfo presentInfo = new IOInfo(1, TYPE_CONSTRAINTS, "present", true, null, false, null, 0, 240, null);
        OUTPUTS_INFO = CollectionsKt.listOf((Object[])new IOInfo[]{outputInfo, presentInfo});

        INFO = new OperatorInfo("Attention", (Collection<AttributeInfo>)ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO);
    }

    /**
     * Returns the {@code num_heads} attribute, read through its attribute delegate and
     * converted to an {@code int}.
     */
    private final int getNumHeads() {
        KProperty numHeadsProperty = $$delegatedProperties[0];
        Number rawValue = (Number)this.numHeads$delegate.getValue(this, numHeadsProperty);
        return rawValue.intValue();
    }

    /**
     * Returns the {@code unidirectional} flag, read through its attribute delegate as a
     * {@link Boolean} and unboxed.
     */
    private final boolean getUnidir() {
        Object rawFlag = this.unidir$delegate.getValue(this, $$delegatedProperties[1]);
        Boolean flag = (Boolean)rawFlag;
        return flag.booleanValue();
    }

    /*
     * WARNING - void declaration
     */
    @Override
    @NotNull
    public List<Tensor> apply(@NotNull Context context, @NotNull List<Tensor> inputs) {
        void scores;
        void keys;
        void queries;
        void seqLen;
        void batchSize;
        NDArray nDArray;
        Object object;
        Intrinsics.checkParameterIsNotNull((Object)context, (String)"context");
        Intrinsics.checkParameterIsNotNull(inputs, (String)"inputs");
        Tensor tensor = inputs.get(0);
        if (tensor == null) {
            Intrinsics.throwNpe();
        }
        NDArray input = tensor.getData();
        Tensor tensor2 = inputs.get(1);
        if (tensor2 == null) {
            Intrinsics.throwNpe();
        }
        NDArray weights = tensor2.getData();
        Tensor tensor3 = inputs.get(2);
        if (tensor3 == null) {
            Intrinsics.throwNpe();
        }
        NDArray bias = tensor3.getData();
        List<Tensor> list = inputs;
        int n = 3;
        int n2 = 0;
        Tensor tensor4 = (Tensor)CollectionsKt.getOrNull(list, (int)n);
        IntNDArray maskIndices = (IntNDArray)(tensor4 != null ? tensor4.getData() : null);
        List<Tensor> list2 = inputs;
        n2 = 4;
        boolean bl = false;
        Tensor tensor5 = (Tensor)CollectionsKt.getOrNull(list2, (int)n2);
        NDArray past = tensor5 != null ? tensor5.getData() : null;
        Object object2 = object = input.getShape();
        boolean bl2 = false;
        int n3 = object2[0];
        object2 = object;
        bl2 = false;
        n2 = object2[1];
        object2 = object;
        bl2 = false;
        int hiddenSize = object2[2];
        NDArray nDArray2 = nDArray = Companion.initQueryKeyValue$inference(input, weights, bias, (int)batchSize, (int)seqLen, hiddenSize, this.getNumHeads());
        boolean bl3 = false;
        object = nDArray2[0];
        nDArray2 = nDArray;
        bl3 = false;
        object2 = nDArray2[1];
        nDArray2 = nDArray;
        bl3 = false;
        MutableNDArray values = nDArray2[2];
        Pair<NDArray, NDArray> pair = Companion.getScores$inference(this.getUnidir(), (NDArray)queries, (NDArray)keys, (NDArray)values, maskIndices, past, (int)batchSize, (int)seqLen, this.getNumHeads(), hiddenSize);
        nDArray = (NDArray)pair.component1();
        NDArray present = (NDArray)pair.component2();
        return CollectionsKt.listOf((Object[])new Tensor[]{TensorExtensionsKt.asTensor$default((NDArray)scores, null, 1, null), TensorExtensionsKt.asTensor$default(present, null, 1, null)});
    }

    /**
     * Creates an Attention operator node described by {@code INFO}.
     *
     * @param attributes node attributes; {@code num_heads} and {@code unidirectional} are bound
     *                   to property delegates via {@code attribute(...)}
     * @param inputs     names of the node's inputs
     * @param outputs    names of the node's outputs
     */
    public Attention(@NotNull Map<String, Attribute<Object>> attributes, @NotNull List<String> inputs, @NotNull List<String> outputs) {
        Intrinsics.checkParameterIsNotNull(attributes, (String)"attributes");
        Intrinsics.checkParameterIsNotNull(inputs, (String)"inputs");
        Intrinsics.checkParameterIsNotNull(outputs, (String)"outputs");
        // NOTE(review): decompiler artifact — standard Java requires super(...) to be the first
        // statement; the Kotlin compiler emitted these null checks ahead of it in bytecode.
        super(INFO, attributes, inputs, outputs);
        // `numHeads.2.INSTANCE` / `unidir.2.INSTANCE` appear to be decompiler names for synthetic
        // Kotlin lambda singletons used as attribute converters — TODO confirm in the Kotlin source.
        this.numHeads$delegate = this.attribute("num_heads", numHeads.2.INSTANCE);
        this.unidir$delegate = this.attribute("unidirectional", unidir.2.INSTANCE);
    }

    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000p\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\"\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\b\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0011\n\u0002\b\u000b\n\u0002\u0018\u0002\n\u0002\b\t\b\u0086\u0003\u0018\u00002\u00020\u0001B\u0007\b\u0002\u00a2\u0006\u0002\u0010\u0002J^\u0010\u000e\u001a\u000e\u0012\u0004\u0012\u00020\u0010\u0012\u0004\u0012\u00020\u00100\u000f2\u0006\u0010\u0011\u001a\u00020\u00102\u0006\u0010\u0012\u001a\u00020\u00102\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u00142\u0006\u0010\u0016\u001a\u00020\u00142\u0006\u0010\u0017\u001a\u00020\u00142\u0006\u0010\u0018\u001a\u00020\u00142\b\u0010\u0019\u001a\u0004\u0018\u00010\u00102\u0006\u0010\u001a\u001a\u00020\u001bH\u0002J\"\u0010\u001c\u001a\u00020\u001d2\b\u0010\u0019\u001a\u0004\u0018\u00010\u00102\u0006\u0010\u001e\u001a\u00020\u00102\u0006\u0010\u001a\u001a\u00020\u001bH\u0002Jm\u0010\u001f\u001a\u000e\u0012\u0004\u0012\u00020\u0010\u0012\u0004\u0012\u00020\u00100\u000f2\u0006\u0010 
\u001a\u00020!2\u0006\u0010\"\u001a\u00020\u00102\u0006\u0010#\u001a\u00020\u00102\u0006\u0010$\u001a\u00020\u00102\b\u0010%\u001a\u0004\u0018\u00010&2\b\u0010\u0019\u001a\u0004\u0018\u00010\u00102\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u00142\u0006\u0010\u0017\u001a\u00020\u00142\u0006\u0010\u0018\u001a\u00020\u0014H\u0000\u00a2\u0006\u0002\b'JM\u0010(\u001a\b\u0012\u0004\u0012\u00020\u001b0)2\u0006\u0010*\u001a\u00020\u00102\u0006\u0010+\u001a\u00020\u00102\u0006\u0010,\u001a\u00020\u00102\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u00142\u0006\u0010\u0018\u001a\u00020\u00142\u0006\u0010\u0017\u001a\u00020\u0014H\u0000\u00a2\u0006\u0004\b-\u0010.Jd\u0010/\u001a\u00020\u00102\u0006\u0010 \u001a\u00020!2\u0006\u00100\u001a\u00020\u00102\u0006\u00101\u001a\u00020\u00102\b\u00102\u001a\u0004\u0018\u00010&2\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u00142\u0006\u0010\u0016\u001a\u00020\u00142\u0006\u00103\u001a\u00020\u00142\u0006\u0010\u0017\u001a\u00020\u00142\b\u0010\u0019\u001a\u0004\u0018\u00010\u00102\u0006\u0010\u001a\u001a\u00020\u001bH\u0002J.\u00104\u001a\u000205*\u0004\u0018\u00010&2\u0006\u0010 
\u001a\u00020!2\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u00142\u0006\u0010\u0016\u001a\u00020\u0014H\u0002JZ\u00106\u001a\u000e\u0012\u0004\u0012\u00020\u001b\u0012\u0004\u0012\u00020\u00140\u000f*\u00020\u001b2\b\u0010\u0019\u001a\u0004\u0018\u00010\u00102\u0006\u00107\u001a\u00020\u00102\u0006\u00108\u001a\u00020\u00142\u0006\u00109\u001a\u00020\u00142\u0006\u0010:\u001a\u00020\u00142\u0006\u0010;\u001a\u00020\u00142\u0006\u0010<\u001a\u00020\u00142\u0006\u0010=\u001a\u00020\u0014H\u0002R\u0014\u0010\u0003\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010\b\u001a\b\u0012\u0004\u0012\u00020\t0\u0004X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010\n\u001a\b\u0012\u0004\u0012\u00020\t0\u0004X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u0006>"}, d2={"Lio/kinference/operators/layer/attention/Attention$Companion;", "", "()V", "ATTRIBUTES_INFO", "", "Lio/kinference/operators/AttributeInfo;", "INFO", "Lio/kinference/operators/OperatorInfo;", "INPUTS_INFO", "Lio/kinference/operators/IOInfo;", "OUTPUTS_INFO", "TYPE_CONSTRAINTS", "", "Lio/kinference/onnx/TensorProto$DataType;", "attentionScore", "Lkotlin/Pair;", "Lio/kinference/ndarray/arrays/NDArray;", "scores", "values", "batchSize", "", "seqLen", "pastSeqLen", "numHeads", "hiddenSize", "past", "present", "Lio/kinference/ndarray/arrays/MutableNDArray;", "concatStateChunk", "", "chunk", "getScores", "unidir", "", "q", "k", "v", "mask", "Lio/kinference/ndarray/arrays/IntNDArray;", "getScores$inference", "initQueryKeyValue", "", "input", "weights", "bias", "initQueryKeyValue$inference", "(Lio/kinference/ndarray/arrays/NDArray;Lio/kinference/ndarray/arrays/NDArray;Lio/kinference/ndarray/arrays/NDArray;IIII)[Lio/kinference/ndarray/arrays/MutableNDArray;", "normalizedScores", 
"queries", "keys", "maskIndices", "headSize", "maskFromIndices", "Lio/kinference/ndarray/arrays/FloatNDArray;", "updateState", "currentState", "pastBlockSize", "presentBlockSize", "i", "pastOffset", "presentOffset", "currentOffset", "inference"})
    public static final class Companion {
        /**
         * Projects {@code input} into three per-head tensors (index 0, 1, 2 are used by the caller
         * as queries, keys and values).
         *
         * <p>{@code weights} and {@code bias} are each split into 3 horizontal blocks, and each of
         * those into {@code numHeads2} per-head blocks. Inside {@code runBlocking} on
         * {@code Dispatchers.Default}, one child coroutine is launched per Q/K/V index; for every
         * batch item and head it computes {@code input[batch] . weights[head]} into
         * {@code output[batch, head]} and then adds {@code bias[head]} in place.
         *
         * <p>WARNING: CFR emitted "void" declarations and mismatched captured names in this body;
         * it is decompiler output and does not compile as Java.
         *
         * @param input      input tensor; must be a NumberNDArray (checked by cast), viewed per batch
         * @param weights    weights, presumably the concatenated Q/K/V projection matrix
         * @param bias       bias, presumably the concatenated Q/K/V bias
         * @param batchSize  number of batch items to project
         * @param seqLen     sequence length; only used for the allocated output shape
         * @param hiddenSize hidden size; {@code hiddenSize / numHeads2} is the per-head size
         * @param numHeads2  number of attention heads
         * @return three freshly allocated arrays, each of shape
         *         {@code [batchSize, numHeads2, seqLen, hiddenSize / numHeads2]} and of
         *         {@code input}'s element type
         */
        @NotNull
        public final MutableNDArray[] initQueryKeyValue$inference(@NotNull NDArray input, @NotNull NDArray weights, @NotNull NDArray bias, int batchSize, int seqLen, int hiddenSize, int numHeads2) {
            void $this$mapTo$iv$iv;
            NDArray[] nDArrayArray;
            Object object;
            int n;
            void $this$mapTo$iv$iv2;
            Intrinsics.checkParameterIsNotNull((Object)input, (String)"input");
            Intrinsics.checkParameterIsNotNull((Object)weights, (String)"weights");
            Intrinsics.checkParameterIsNotNull((Object)bias, (String)"bias");
            // Fails with ClassCastException if input is not a NumberNDArray (Kotlin `input as NumberNDArray`).
            NumberNDArray cfr_ignored_0 = (NumberNDArray)input;
            int headSize = hiddenSize / numHeads2;
            // Weights: split into 3 blocks, then each block into numHeads2 per-head blocks.
            NDArray[] $this$map$iv = weights.splitHorizontalByBlocks(3);
            boolean $i$f$map = false;
            NDArray[] nDArrayArray2 = $this$map$iv;
            // NOTE(review): decompiler artifact — an ArrayList is cast to NDArray[] and the loop
            // iterates an unassigned `void` local; the intent appears to be a map over $this$map$iv.
            NDArray[] destination$iv$iv = (NDArray[])new ArrayList($this$map$iv.length);
            boolean $i$f$mapTo = false;
            void var15_18 = $this$mapTo$iv$iv2;
            int n2 = ((void)var15_18).length;
            for (n = 0; n < n2; ++n) {
                void it;
                void item$iv$iv;
                void var19_26 = item$iv$iv = var15_18[n];
                object = destination$iv$iv;
                boolean bl = false;
                nDArrayArray = it.splitHorizontalByBlocks(numHeads2);
                object.add(nDArrayArray);
            }
            List qkvWeights = (List)destination$iv$iv;
            // Bias: same 3-way then per-head split as the weights.
            NDArray[] $this$map$iv2 = bias.splitHorizontalByBlocks(3);
            boolean $i$f$map2 = false;
            destination$iv$iv = $this$map$iv2;
            Collection destination$iv$iv2 = new ArrayList($this$map$iv2.length);
            int $i$f$mapTo2 = 0;
            void var16_21 = $this$mapTo$iv$iv;
            n = ((void)var16_21).length;
            for (int i = 0; i < n; ++i) {
                void it;
                void item$iv$iv;
                void bl = item$iv$iv = var16_21[i];
                object = destination$iv$iv2;
                boolean bl2 = false;
                nDArrayArray = it.splitHorizontalByBlocks(numHeads2);
                object.add(nDArrayArray);
            }
            List qkvBias = (List)destination$iv$iv2;
            // Allocate the three outputs, each [batchSize, numHeads2, seqLen, headSize].
            int n3 = 3;
            MutableNDArray[] mutableNDArrayArray = new MutableNDArray[n3];
            int n4 = 0;
            while (n4 < n3) {
                $i$f$mapTo2 = n4;
                int n5 = n4++;
                object = mutableNDArrayArray;
                boolean bl = false;
                MutableNDArray mutableNDArray = ArrayFactoriesKt.allocateNDArray((DataType)input.getType(), (Strides)new Strides(new int[]{batchSize, numHeads2, seqLen, headSize}));
                object[n5] = mutableNDArray;
            }
            MutableNDArray[] qkv = mutableNDArrayArray;
            // Compute the three projections concurrently; runBlocking returns once every launched
            // child coroutine has completed, so qkv is fully written when it is returned.
            BuildersKt.runBlocking((CoroutineContext)((CoroutineContext)Dispatchers.getDefault()), (Function2)((Function2)new Function2<CoroutineScope, Continuation<? super Unit>, Object>(qkv, qkvWeights, qkvBias, batchSize, input, numHeads2, null){
                private CoroutineScope p$;
                int label;
                final /* synthetic */ MutableNDArray[] $qkv;
                final /* synthetic */ List $qkvWeights;
                final /* synthetic */ List $qkvBias;
                final /* synthetic */ int $batchSize;
                final /* synthetic */ NDArray $input;
                final /* synthetic */ int $numHeads;

                /*
                 * WARNING - void declaration
                 * Outer coroutine body: launches one child per Q/K/V index (0..2).
                 */
                @Nullable
                public final Object invokeSuspend(@NotNull Object $result) {
                    Object object = IntrinsicsKt.getCOROUTINE_SUSPENDED();
                    switch (this.label) {
                        case 0: {
                            ResultKt.throwOnFailure((Object)$result);
                            CoroutineScope $this$runBlocking = this.p$;
                            int n = 0;
                            int n2 = 3;
                            // NOTE(review): decompiler lost the loop-variable binding; intent appears
                            // to be qkvIdx in [0, 3).
                            while (n < n2) {
                                void qkvIdx;
                                // NOTE(review): the trailing (int)3, null look like launch's default-argument
                                // mask and marker (context/start left at defaults) — confirm.
                                BuildersKt.launch$default((CoroutineScope)$this$runBlocking, null, null, (Function2)((Function2)new Function2<CoroutineScope, Continuation<? super Unit>, Object>(this, (int)qkvIdx, null){
                                    private CoroutineScope p$;
                                    int label;
                                    final /* synthetic */ initQueryKeyValue.1 this$0;
                                    final /* synthetic */ int $qkvIdx;

                                    /*
                                     * WARNING - void declaration
                                     * Child coroutine body: fills qkv[qkvIdx] for every batch item and head.
                                     */
                                    @Nullable
                                    public final Object invokeSuspend(@NotNull Object $result) {
                                        Object object = IntrinsicsKt.getCOROUTINE_SUSPENDED();
                                        switch (this.label) {
                                            case 0: {
                                                ResultKt.throwOnFailure((Object)$result);
                                                CoroutineScope $this$launch = this.p$;
                                                MutableNDArray output = this.this$0.$qkv[this.$qkvIdx];
                                                NDArray[] weights = (NDArray[])this.this$0.$qkvWeights.get(this.$qkvIdx);
                                                NDArray[] bias = (NDArray[])this.this$0.$qkvBias.get(this.$qkvIdx);
                                                int n = 0;
                                                int n2 = this.this$0.$batchSize;
                                                // NOTE(review): intent appears to be batchNum in [0, batchSize).
                                                while (n < n2) {
                                                    void batchNum;
                                                    NDArray inputMatrix = this.this$0.$input.view(new int[]{batchNum});
                                                    int n3 = 0;
                                                    int n4 = this.this$0.$numHeads;
                                                    // NOTE(review): intent appears to be numHead in [0, numHeads).
                                                    while (n3 < n4) {
                                                        void numHead;
                                                        MutableNDArray outputMatrix = output.viewMutable(new int[]{batchNum, numHead});
                                                        NDArray weightsMatrix = weights[numHead];
                                                        NDArray biasMatrix = bias[numHead];
                                                        NDArray nDArray = inputMatrix;
                                                        if (nDArray == null) {
                                                            throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
                                                        }
                                                        NDArray nDArray2 = weightsMatrix;
                                                        if (nDArray2 == null) {
                                                            throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
                                                        }
                                                        MutableNDArray mutableNDArray = outputMatrix;
                                                        if (mutableNDArray == null) {
                                                            throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.MutableNumberNDArray");
                                                        }
                                                        // outputMatrix <- inputMatrix . weightsMatrix, then += biasMatrix in place.
                                                        ((NumberNDArray)nDArray).dot((NumberNDArray)nDArray2, (MutableNumberNDArray)mutableNDArray);
                                                        ((MutableNumberNDArray)outputMatrix).plusAssign(biasMatrix);
                                                        ++numHead;
                                                    }
                                                    ++batchNum;
                                                }
                                                return Unit.INSTANCE;
                                            }
                                        }
                                        throw new IllegalStateException("call to 'resume' before 'invoke' with coroutine");
                                    }
                                    {
                                        this.this$0 = var1_1;
                                        this.$qkvIdx = n;
                                        super(2, continuation);
                                    }

                                    @NotNull
                                    public final Continuation<Unit> create(@Nullable Object value, @NotNull Continuation<?> completion) {
                                        Intrinsics.checkParameterIsNotNull(completion, (String)"completion");
                                        Function2<CoroutineScope, Continuation<? super Unit>, Object> function2 = new /* invalid duplicate definition of identical inner class */;
                                        CoroutineScope coroutineScope = function2.p$ = (CoroutineScope)value;
                                        return function2;
                                    }

                                    public final Object invoke(Object object, Object object2) {
                                        return (this.create(object, (Continuation)object2)).invokeSuspend(Unit.INSTANCE);
                                    }
                                }), (int)3, null);
                                ++qkvIdx;
                            }
                            return Unit.INSTANCE;
                        }
                    }
                    throw new IllegalStateException("call to 'resume' before 'invoke' with coroutine");
                }
                // NOTE(review): decompiler artifact — these names do not match the constructor
                // arguments above (qkv, qkvWeights, qkvBias, batchSize, input, numHeads2).
                {
                    this.$qkv = mutableNDArrayArray;
                    this.$qkvWeights = list;
                    this.$qkvBias = list2;
                    this.$batchSize = n;
                    this.$input = nDArray;
                    this.$numHeads = n2;
                    super(2, continuation);
                }

                @NotNull
                public final Continuation<Unit> create(@Nullable Object value, @NotNull Continuation<?> completion) {
                    Intrinsics.checkParameterIsNotNull(completion, (String)"completion");
                    Function2<CoroutineScope, Continuation<? super Unit>, Object> function2 = new /* invalid duplicate definition of identical inner class */;
                    CoroutineScope coroutineScope = function2.p$ = (CoroutineScope)value;
                    return function2;
                }

                public final Object invoke(Object object, Object object2) {
                    return (this.create(object, (Continuation)object2)).invokeSuspend(Unit.INSTANCE);
                }
            }));
            return qkv;
        }

        /*
         * WARNING - void declaration
         */
        private final FloatNDArray maskFromIndices(@Nullable IntNDArray $this$maskFromIndices, boolean unidir2, int batchSize, int seqLen, int pastSeqLen) {
            int fullSeqLen = seqLen + pastSeqLen;
            int[] maskDataShape = new int[]{batchSize, seqLen, fullSeqLen};
            MutableNDArray mutableNDArray = ArrayFactoriesKt.allocateNDArray((DataType)DataType.FLOAT, (Strides)new Strides(maskDataShape));
            if (mutableNDArray == null) {
                throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.MutableFloatNDArray");
            }
            MutableFloatNDArray mask = (MutableFloatNDArray)mutableNDArray;
            int maskOffset = seqLen * fullSeqLen;
            boolean bl = false;
            int n = 0;
            int n2 = batchSize;
            for (n = 0; n < n2; ++n) {
                int n3;
                int n4;
                int n5;
                int n6;
                int n7;
                int i = n;
                boolean bl2 = false;
                if ($this$maskFromIndices != null) {
                    int $i$f$increment;
                    int $i$f$blockIncrement;
                    int end$iv;
                    int count$iv;
                    if ($this$maskFromIndices.getRank() == 2) {
                        float[] dstBlock$iv;
                        void $this$accept$iv;
                        FloatPointer maskPointer = mask.getArray().pointer(maskOffset * i);
                        IntPointer maskIndicesPointer = $this$maskFromIndices.getArray().pointer(i * fullSeqLen);
                        FloatPointer floatPointer = maskPointer;
                        count$iv = fullSeqLen;
                        boolean $i$f$accept = false;
                        void $this$isCompatibleBySize$iv$iv22 = $this$accept$iv;
                        boolean $i$f$isCompatibleBySize = false;
                        boolean $this$isCompatibleBySize$iv$iv22 = $this$isCompatibleBySize$iv$iv22.getArray().getSize() - $this$isCompatibleBySize$iv$iv22.getLinearIndex() >= count$iv && maskIndicesPointer.getArray().getSize() - maskIndicesPointer.getLinearIndex() >= count$iv;
                        $i$f$isCompatibleBySize = false;
                        boolean bl3 = false;
                        if (!$this$isCompatibleBySize$iv$iv22) {
                            boolean bl4 = false;
                            String string = "Pointers not compatible by available elements";
                            throw (Throwable)new IllegalArgumentException(string.toString());
                        }
                        void $this$isCompatibleWith$iv$iv = $this$accept$iv;
                        boolean $i$f$isCompatibleWith = false;
                        if ($this$isCompatibleWith$iv$iv.getIndexInBlock() == maskIndicesPointer.getIndexInBlock() && $this$isCompatibleWith$iv$iv.getArray().getBlockSize() == maskIndicesPointer.getArray().getBlockSize()) {
                            int dstOffset$iv;
                            for (end$iv = count$iv; end$iv > 0; end$iv -= dstBlock$iv.length - dstOffset$iv) {
                                dstBlock$iv = $this$accept$iv.getCurrentBlock();
                                dstOffset$iv = $this$accept$iv.getIndexInBlock();
                                void this_$iv$iv = $this$accept$iv;
                                $i$f$blockIncrement = 0;
                                if (this_$iv$iv.getBlockNum() < this_$iv$iv.getArray().getBlocksNum() - 1) {
                                    void v1 = this_$iv$iv;
                                    int n8 = v1.getBlockNum();
                                    v1.setBlockNum(n8 + 1);
                                    this_$iv$iv.setIndexInBlock(0);
                                    this_$iv$iv.setCurrentBlock(this_$iv$iv.getArray().getBlocks()[this_$iv$iv.getBlockNum()]);
                                } else {
                                    this_$iv$iv.setIndexInBlock(this_$iv$iv.getArray().getBlockSize());
                                }
                                int[] srcBlock$iv = maskIndicesPointer.getCurrentBlock();
                                IntPointer this_$iv$iv22 = maskIndicesPointer;
                                int $i$f$blockIncrement2 = 0;
                                if (this_$iv$iv22.getBlockNum() < this_$iv$iv22.getArray().getBlocksNum() - 1) {
                                    IntPointer intPointer = this_$iv$iv22;
                                    n7 = intPointer.getBlockNum();
                                    intPointer.setBlockNum(n7 + 1);
                                    this_$iv$iv22.setIndexInBlock(0);
                                    this_$iv$iv22.setCurrentBlock(this_$iv$iv22.getArray().getBlocks()[this_$iv$iv22.getBlockNum()]);
                                } else {
                                    this_$iv$iv22.setIndexInBlock(this_$iv$iv22.getArray().getBlockSize());
                                }
                                int this_$iv$iv22 = dstOffset$iv;
                                n7 = dstBlock$iv.length;
                                n6 = dstOffset$iv + end$iv;
                                n5 = 0;
                                $i$f$blockIncrement2 = Math.min(n7, n6);
                                while (this_$iv$iv22 < $i$f$blockIncrement2) {
                                    void src;
                                    float f;
                                    void index$iv;
                                    n4 = srcBlock$iv[index$iv];
                                    float f2 = dstBlock$iv[index$iv];
                                    n3 = index$iv;
                                    float[] fArray = dstBlock$iv;
                                    boolean bl5 = false;
                                    fArray[n3] = f = src > 0 ? 0.0f : -10000.0f;
                                    ++index$iv;
                                }
                            }
                        } else {
                            while (end$iv > 0) {
                                void src;
                                int n9;
                                dstBlock$iv = $this$accept$iv;
                                IntPointer this_$iv$iv = $this$accept$iv;
                                boolean $i$f$get = false;
                                float bl5 = this_$iv$iv.getCurrentBlock()[this_$iv$iv.getIndexInBlock()];
                                this_$iv$iv = maskIndicesPointer;
                                boolean $i$f$getAndIncrement = false;
                                int value$iv$iv = this_$iv$iv.getCurrentBlock()[this_$iv$iv.getIndexInBlock()];
                                IntPointer this_$iv$iv$iv3 = this_$iv$iv;
                                $i$f$increment = 0;
                                if (this_$iv$iv$iv3.getIndexInBlock() < this_$iv$iv$iv3.getArray().getBlockSize() - 1) {
                                    IntPointer intPointer = this_$iv$iv$iv3;
                                    int n10 = intPointer.getIndexInBlock();
                                    intPointer.setIndexInBlock(n10 + 1);
                                } else {
                                    IntPointer this_$iv$iv$iv$iv = this_$iv$iv$iv3;
                                    boolean $i$f$blockIncrement3 = false;
                                    if (this_$iv$iv$iv$iv.getBlockNum() < this_$iv$iv$iv$iv.getArray().getBlocksNum() - 1) {
                                        IntPointer intPointer = this_$iv$iv$iv$iv;
                                        int $noName_0 = intPointer.getBlockNum();
                                        intPointer.setBlockNum($noName_0 + 1);
                                        this_$iv$iv$iv$iv.setIndexInBlock(0);
                                        this_$iv$iv$iv$iv.setCurrentBlock(this_$iv$iv$iv$iv.getArray().getBlocks()[this_$iv$iv$iv$iv.getBlockNum()]);
                                    } else {
                                        this_$iv$iv$iv$iv.setIndexInBlock(this_$iv$iv$iv$iv.getArray().getBlockSize());
                                    }
                                }
                                int n11 = n9 = value$iv$iv;
                                float $noName_0 = bl5;
                                boolean $i$a$-accept-Attention$Companion$maskFromIndices$1$2 = false;
                                float value$iv$iv2 = src > 0 ? 0.0f : -10000.0f;
                                int $i$f$set = 0;
                                this_$iv$iv2.getCurrentBlock()[this_$iv$iv2.getIndexInBlock()] = value$iv$iv2;
                                void this_$iv$iv2 = $this$accept$iv;
                                boolean $i$f$increment2 = false;
                                if (this_$iv$iv2.getIndexInBlock() < this_$iv$iv2.getArray().getBlockSize() - 1) {
                                    void v5 = this_$iv$iv2;
                                    $i$f$set = v5.getIndexInBlock();
                                    v5.setIndexInBlock($i$f$set + 1);
                                } else {
                                    void this_$iv$iv$iv2 = this_$iv$iv2;
                                    $i$f$blockIncrement = 0;
                                    if (this_$iv$iv$iv2.getBlockNum() < this_$iv$iv$iv2.getArray().getBlocksNum() - 1) {
                                        void v6 = this_$iv$iv$iv2;
                                        int this_$iv$iv$iv3 = v6.getBlockNum();
                                        v6.setBlockNum(this_$iv$iv$iv3 + 1);
                                        this_$iv$iv$iv2.setIndexInBlock(0);
                                        this_$iv$iv$iv2.setCurrentBlock(this_$iv$iv$iv2.getArray().getBlocks()[this_$iv$iv$iv2.getBlockNum()]);
                                    } else {
                                        this_$iv$iv$iv2.setIndexInBlock(this_$iv$iv$iv2.getArray().getBlockSize());
                                    }
                                }
                                --end$iv;
                            }
                        }
                    } else {
                        FloatPointer $this$map$iv;
                        int offset$iv;
                        float[] block$iv;
                        void this_$iv;
                        FloatPointer maskPointer;
                        FloatPointer this_$iv2;
                        IntPointer maskIndicesPointer;
                        IntPointer $this$accept$iv = maskIndicesPointer = $this$maskFromIndices.getArray().pointer(i);
                        n3 = maskOffset * i;
                        FloatTiledArray floatTiledArray = mask.getArray();
                        boolean $i$f$get22 = false;
                        int n12 = this_$iv2.getCurrentBlock()[this_$iv2.getIndexInBlock()];
                        this_$iv2 = maskPointer = floatTiledArray.pointer(n3 + n12);
                        IntPointer $i$f$get22 = maskIndicesPointer;
                        int n13 = fullSeqLen;
                        int $i$f$get = 0;
                        n3 = this_$iv.getCurrentBlock()[this_$iv.getIndexInBlock()];
                        count$iv = n13 - n3;
                        boolean $i$f$map = false;
                        for (end$iv = count$iv; end$iv > 0; end$iv -= block$iv.length - offset$iv) {
                            int this_$iv$iv$iv3;
                            block$iv = $this$map$iv.getCurrentBlock();
                            offset$iv = $this$map$iv.getIndexInBlock();
                            void this_$iv$iv32 = $this$map$iv;
                            $i$f$blockIncrement = 0;
                            if (this_$iv$iv32.getBlockNum() < this_$iv$iv32.getArray().getBlocksNum() - 1) {
                                void v7 = this_$iv$iv32;
                                this_$iv$iv$iv3 = v7.getBlockNum();
                                v7.setBlockNum(this_$iv$iv$iv3 + 1);
                                this_$iv$iv32.setIndexInBlock(0);
                                this_$iv$iv32.setCurrentBlock(this_$iv$iv32.getArray().getBlocks()[this_$iv$iv32.getBlockNum()]);
                            } else {
                                this_$iv$iv32.setIndexInBlock(this_$iv$iv32.getArray().getBlockSize());
                            }
                            int this_$iv$iv32 = offset$iv;
                            this_$iv$iv$iv3 = block$iv.length;
                            $i$f$increment = offset$iv + end$iv;
                            n6 = 0;
                            $i$f$blockIncrement = Math.min(this_$iv$iv$iv3, $i$f$increment);
                            while (this_$iv$iv32 < $i$f$blockIncrement) {
                                float f;
                                void index$iv;
                                float $i$f$blockIncrement3 = block$iv[index$iv];
                                n3 = index$iv++;
                                float[] fArray = block$iv;
                                boolean bl6 = false;
                                fArray[n3] = f = -10000.0f;
                            }
                        }
                        if ($this$maskFromIndices.getRank() == 1 && $this$maskFromIndices.getShape()[0] == 2 * batchSize) {
                            float[] block$iv2;
                            maskIndicesPointer.setLinearIndex(i + batchSize);
                            maskPointer.setLinearIndex(maskOffset * i);
                            $this$map$iv = maskPointer;
                            IntPointer this_$iv222 = maskIndicesPointer;
                            $i$f$get = 0;
                            int this_$iv222 = this_$iv222.getCurrentBlock()[this_$iv222.getIndexInBlock()];
                            $i$f$get = fullSeqLen;
                            end$iv = 0;
                            count$iv = Math.min(this_$iv222, $i$f$get);
                            $i$f$map = false;
                            for (end$iv = count$iv; end$iv > 0; end$iv -= block$iv2.length - offset$iv) {
                                int this_$iv$iv$iv3;
                                block$iv2 = $this$map$iv.getCurrentBlock();
                                offset$iv = $this$map$iv.getIndexInBlock();
                                FloatPointer this_$iv$iv42 = $this$map$iv;
                                $i$f$blockIncrement = 0;
                                if (this_$iv$iv42.getBlockNum() < this_$iv$iv42.getArray().getBlocksNum() - 1) {
                                    FloatPointer floatPointer = this_$iv$iv42;
                                    this_$iv$iv$iv3 = floatPointer.getBlockNum();
                                    floatPointer.setBlockNum(this_$iv$iv$iv3 + 1);
                                    this_$iv$iv42.setIndexInBlock(0);
                                    this_$iv$iv42.setCurrentBlock(this_$iv$iv42.getArray().getBlocks()[this_$iv$iv42.getBlockNum()]);
                                } else {
                                    this_$iv$iv42.setIndexInBlock(this_$iv$iv42.getArray().getBlockSize());
                                }
                                int this_$iv$iv42 = offset$iv;
                                this_$iv$iv$iv3 = block$iv2.length;
                                $i$f$increment = offset$iv + end$iv;
                                n6 = 0;
                                $i$f$blockIncrement = Math.min(this_$iv$iv$iv3, $i$f$increment);
                                while (this_$iv$iv42 < $i$f$blockIncrement) {
                                    float f;
                                    void index$iv;
                                    float it = block$iv2[index$iv];
                                    n3 = index$iv++;
                                    float[] fArray = block$iv2;
                                    boolean bl7 = false;
                                    fArray[n3] = f = -10000.0f;
                                }
                            }
                        }
                    }
                }
                int maskIndicesPointer = 1;
                int maskPointer = seqLen;
                while (maskIndicesPointer < maskPointer) {
                    void seqIdx;
                    void start = seqIdx * fullSeqLen + i * maskOffset;
                    mask.copyFrom((int)start, (NDArray)mask, i * maskOffset, i * maskOffset + fullSeqLen);
                    ++seqIdx;
                }
                if (!unidir2) continue;
                FloatPointer maskPointer2 = FloatTiledArray.pointer$default((FloatTiledArray)mask.getArray(), (int)0, (int)1, null);
                maskPointer = 0;
                int n14 = seqLen - 1;
                while (maskPointer < n14) {
                    int offset$iv;
                    float[] block$iv;
                    void seqIdx;
                    int start = pastSeqLen + seqIdx + 1;
                    maskPointer2.setLinearIndex((int)(seqIdx * fullSeqLen + maskOffset * i + start));
                    FloatPointer $i$f$map = maskPointer2;
                    int count$iv = fullSeqLen - start;
                    boolean $i$f$map2 = false;
                    for (int end$iv = count$iv; end$iv > 0; end$iv -= block$iv.length - offset$iv) {
                        void $this$map$iv;
                        block$iv = $this$map$iv.getCurrentBlock();
                        offset$iv = $this$map$iv.getIndexInBlock();
                        void this_$iv$iv52 = $this$map$iv;
                        boolean $i$f$blockIncrement = false;
                        if (this_$iv$iv52.getBlockNum() < this_$iv$iv52.getArray().getBlocksNum() - 1) {
                            void v9 = this_$iv$iv52;
                            n6 = v9.getBlockNum();
                            v9.setBlockNum(n6 + 1);
                            this_$iv$iv52.setIndexInBlock(0);
                            this_$iv$iv52.setCurrentBlock(this_$iv$iv52.getArray().getBlocks()[this_$iv$iv52.getBlockNum()]);
                        } else {
                            this_$iv$iv52.setIndexInBlock(this_$iv$iv52.getArray().getBlockSize());
                        }
                        int this_$iv$iv52 = offset$iv;
                        n6 = block$iv.length;
                        n5 = offset$iv + end$iv;
                        n4 = 0;
                        n7 = Math.min(n6, n5);
                        while (this_$iv$iv52 < n7) {
                            void it;
                            void index$iv;
                            float $noName_0 = block$iv[index$iv];
                            n3 = index$iv++;
                            float[] fArray = block$iv;
                            boolean bl8 = false;
                            void var34_86 = it - 10000.0f;
                            fArray[n3] = var34_86;
                        }
                    }
                    ++seqIdx;
                }
            }
            return (FloatNDArray)mask;
        }

        /**
         * Writes the i-th block of the receiver: first the i-th {@code pastBlockSize}-long block of
         * {@code past} (when {@code past} is non-null), then {@code presentBlockSize - pastBlockSize}
         * elements of {@code currentState} starting at {@code currentOffset}.
         *
         * <p>When {@code past} is null, the slice of {@code currentState} is written at the very start
         * of the block. The copied length is still {@code presentBlockSize - pastBlockSize}.
         * NOTE(review): this presumably relies on callers passing {@code pastBlockSize == 0} in that case — confirm.
         *
         * @return the receiver paired with the linear start index of the block that was written
         */
        private final Pair<MutableNDArray, Integer> updateState(@NotNull MutableNDArray $this$updateState, NDArray past, NDArray currentState, int pastBlockSize, int presentBlockSize, int i, int pastOffset, int presentOffset, int currentOffset) {
            // First linear index of block i in the destination.
            final int blockStart = i * presentBlockSize + presentOffset;
            int writeCursor = blockStart;
            if (past != null) {
                final int pastFrom = i * pastBlockSize + pastOffset;
                final int pastTo = pastFrom + pastBlockSize;
                $this$updateState.copyFrom(writeCursor, past, pastFrom, pastTo);
                writeCursor += pastBlockSize;
            }
            // The remainder of the block comes from the current state.
            final int currentLength = presentBlockSize - pastBlockSize;
            final int currentTo = currentOffset + currentLength;
            $this$updateState.copyFrom(writeCursor, currentState, currentOffset, currentTo);
            return TuplesKt.to((Object)$this$updateState, (Object)blockStart);
        }

        /**
         * Fills {@code present} with {@code past} (if any) at offset 0, followed immediately by {@code chunk}.
         *
         * <p>Both copies go through Kotlin's synthetic {@code copyFrom$default}. The mask {@code 12}
         * marks the source start/end arguments as defaulted, so the literal {@code 0, 0} placeholders
         * in those slots are ignored. That presumably means each whole source array is copied — confirm
         * against {@code MutableNDArray.copyFrom}.
         *
         * @param past optional earlier state; when null, {@code chunk} is written at offset 0
         * @param chunk new state appended after {@code past}
         * @param present destination array receiving the concatenation
         */
        private final void concatStateChunk(NDArray past, NDArray chunk, MutableNDArray present) {
            final int chunkStart;
            if (past == null) {
                chunkStart = 0;
            } else {
                MutableNDArray.DefaultImpls.copyFrom$default(present, 0, past, 0, 0, 12, null);
                chunkStart = past.getLinearSize();
            }
            MutableNDArray.DefaultImpls.copyFrom$default(present, chunkStart, chunk, 0, 0, 12, null);
        }

        /**
         * Computes softmax-normalized attention scores for every (batch, head) pair.
         *
         * <p>For each batch index and head (one coroutine each):
         * <ul>
         *   <li>the current key chunk is appended after the past key state (if any) into
         *       {@code present[0, batch, head]};</li>
         *   <li>the queries are combined with that concatenated key state via
         *       {@code dotTransposedWithAlpha} with {@code alpha = 1 / sqrt(headSize)};</li>
         *   <li>the optional mask row for the batch is added.</li>
         * </ul>
         * Softmax is then applied to the whole scores tensor with axis argument {@code -1}.
         *
         * <p>Side effect: writes the concatenated key state into slot 0 of {@code present}.
         *
         * <p>NOTE(review): the mask (including the unidirectional part built by maskFromIndices) is only
         * created when {@code maskIndices != null}. With {@code unidir2 == true} and no mask indices, no
         * causal masking is applied at all. Confirm this is intended.
         *
         * @param unidir2 forwarded to maskFromIndices; presumably enables causal (unidirectional) masking
         * @param queries query tensor, sliced per (batch, head) with {@code view}
         * @param keys current key chunk, sliced per (batch, head) with {@code view}
         * @param maskIndices optional mask indices; when null, no mask is added
         * @param batchSize number of batch entries processed
         * @param seqLen current sequence length (third dimension of the scores tensor)
         * @param pastSeqLen cached sequence length; the last scores dimension is {@code pastSeqLen + seqLen}
         * @param headSize per-head size; used here only for the {@code 1 / sqrt(headSize)} scale
         * @param numHeads2 number of attention heads
         * @param past optional cached state; slot 0 (keys) is read via {@code view(0, batch, head)}
         * @param present destination state; slot 0 receives past keys followed by the current keys
         * @return softmax of the scaled, masked scores (allocated as [batchSize, numHeads, seqLen, pastSeqLen + seqLen])
         */
        private final NDArray normalizedScores(boolean unidir2, NDArray queries, NDArray keys, IntNDArray maskIndices, int batchSize, int seqLen, int pastSeqLen, int headSize, int numHeads2, NDArray past, MutableNDArray present) {
            int allSeqLen = pastSeqLen + seqLen;
            // Raw logits: [batchSize, numHeads, seqLen, pastSeqLen + seqLen], same data type as queries.
            MutableNDArray scores = ArrayFactoriesKt.allocateNDArray((DataType)queries.getType(), (Strides)new Strides(new int[]{batchSize, numHeads2, seqLen, allSeqLen}));
            IntNDArray intNDArray = maskIndices;
            // Mask rows (one per batch, selected below via view(batch)); null when no mask indices were given.
            FloatNDArray maskData = intNDArray != null ? this.maskFromIndices(intNDArray, unidir2, batchSize, seqLen, pastSeqLen) : null;
            // alpha = 1 / sqrt(headSize): scale applied to the query-key products.
            double d = headSize;
            double d2 = 1.0;
            boolean bl = false;
            double d3 = Math.sqrt(d);
            double alpha2 = d2 / d3;
            // Decompiled `runBlocking(Dispatchers.Default) { for (b) for (h) launch { ... } }`.
            // Children are launched in the runBlocking scope, so this call returns only after all of them complete.
            BuildersKt.runBlocking((CoroutineContext)((CoroutineContext)Dispatchers.getDefault()), (Function2)((Function2)new Function2<CoroutineScope, Continuation<? super Unit>, Object>(batchSize, numHeads2, queries, keys, past, present, scores, maskData, alpha2, null){
                private CoroutineScope p$;
                int label;
                final /* synthetic */ int $batchSize;
                final /* synthetic */ int $numHeads;
                final /* synthetic */ NDArray $queries;
                final /* synthetic */ NDArray $keys;
                final /* synthetic */ NDArray $past;
                final /* synthetic */ MutableNDArray $present;
                final /* synthetic */ MutableNDArray $scores;
                final /* synthetic */ FloatNDArray $maskData;
                final /* synthetic */ double $alpha;

                /*
                 * WARNING - void declaration
                 */
                @Nullable
                public final Object invokeSuspend(@NotNull Object $result) {
                    Object object = IntrinsicsKt.getCOROUTINE_SUSPENDED();
                    switch (this.label) {
                        case 0: {
                            ResultKt.throwOnFailure((Object)$result);
                            CoroutineScope $this$runBlocking = this.p$;
                            int n = 0;
                            int n2 = this.$batchSize;
                            // Outer loop over batch indices (decompiler renders the counter as `void batchNum`).
                            while (n < n2) {
                                void batchNum;
                                int n3 = 0;
                                int n4 = this.$numHeads;
                                // Inner loop over heads; each (batch, head) pair gets its own coroutine.
                                while (n3 < n4) {
                                    void numHead;
                                    BuildersKt.launch$default((CoroutineScope)$this$runBlocking, null, null, (Function2)((Function2)new Function2<CoroutineScope, Continuation<? super Unit>, Object>(this, (int)batchNum, (int)numHead, null){
                                        private CoroutineScope p$;
                                        int label;
                                        final /* synthetic */ normalizedScores.1 this$0;
                                        final /* synthetic */ int $batchNum;
                                        final /* synthetic */ int $numHead;

                                        @Nullable
                                        public final Object invokeSuspend(@NotNull Object $result) {
                                            Object object = IntrinsicsKt.getCOROUTINE_SUSPENDED();
                                            switch (this.label) {
                                                case 0: {
                                                    ResultKt.throwOnFailure((Object)$result);
                                                    CoroutineScope $this$launch = this.p$;
                                                    // Per-(batch, head) slices; each coroutine writes only to its own slices of present/scores.
                                                    NDArray queryMatrix = this.this$0.$queries.view(new int[]{this.$batchNum, this.$numHead});
                                                    NDArray keyMatrix = this.this$0.$keys.view(new int[]{this.$batchNum, this.$numHead});
                                                    NDArray nDArray = this.this$0.$past;
                                                    // Slot 0 of past/present holds the key state.
                                                    NDArray pastMatrix = nDArray != null ? nDArray.view(new int[]{0, this.$batchNum, this.$numHead}) : null;
                                                    MutableNDArray presentMatrix = this.this$0.$present.viewMutable(new int[]{0, this.$batchNum, this.$numHead});
                                                    MutableNDArray scoresMatrix = this.this$0.$scores.viewMutable(new int[]{this.$batchNum, this.$numHead});
                                                    FloatNDArray floatNDArray = this.this$0.$maskData;
                                                    FloatNDArray maskVector = floatNDArray != null ? floatNDArray.view(new int[]{this.$batchNum}) : null;
                                                    // present[0, b, h] = past keys (if any) followed by the current key chunk.
                                                    io.kinference.operators.layer.attention.Attention$Companion.access$concatStateChunk(Attention.Companion, pastMatrix, keyMatrix, presentMatrix);
                                                    // Null checks below are Kotlin cast guards emitted by the compiler.
                                                    NDArray nDArray2 = queryMatrix;
                                                    if (nDArray2 == null) {
                                                        throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
                                                    }
                                                    MutableNDArray mutableNDArray = presentMatrix;
                                                    if (mutableNDArray == null) {
                                                        throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
                                                    }
                                                    MutableNDArray mutableNDArray2 = scoresMatrix;
                                                    if (mutableNDArray2 == null) {
                                                        throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.MutableNumberNDArray");
                                                    }
                                                    // scores[b, h] <- queries x (concatenated keys) with alpha; presumably alpha * Q * K^T, per the method name.
                                                    ((NumberNDArray)nDArray2).dotTransposedWithAlpha(this.this$0.$alpha, (NumberNDArray)mutableNDArray, (MutableNumberNDArray)mutableNDArray2);
                                                    // Add this batch's mask row to the scores.
                                                    if (maskVector != null) {
                                                        ((MutableNumberNDArray)scoresMatrix).plusAssign((NDArray)maskVector);
                                                    }
                                                    return Unit.INSTANCE;
                                                }
                                            }
                                            throw new IllegalStateException("call to 'resume' before 'invoke' with coroutine");
                                        }
                                        {
                                            this.this$0 = var1_1;
                                            this.$batchNum = n;
                                            this.$numHead = n2;
                                            super(2, continuation);
                                        }

                                        @NotNull
                                        public final Continuation<Unit> create(@Nullable Object value, @NotNull Continuation<?> completion) {
                                            Intrinsics.checkParameterIsNotNull(completion, (String)"completion");
                                            Function2<CoroutineScope, Continuation<? super Unit>, Object> function2 = new /* invalid duplicate definition of identical inner class */;
                                            CoroutineScope coroutineScope = function2.p$ = (CoroutineScope)value;
                                            return function2;
                                        }

                                        public final Object invoke(Object object, Object object2) {
                                            return (this.create(object, (Continuation)object2)).invokeSuspend(Unit.INSTANCE);
                                        }
                                    }), (int)3, null);
                                    ++numHead;
                                }
                                ++batchNum;
                            }
                            return Unit.INSTANCE;
                        }
                    }
                    throw new IllegalStateException("call to 'resume' before 'invoke' with coroutine");
                }
                {
                    this.$batchSize = n;
                    this.$numHeads = n2;
                    this.$queries = nDArray;
                    this.$keys = nDArray2;
                    this.$past = nDArray3;
                    this.$present = mutableNDArray;
                    this.$scores = mutableNDArray2;
                    this.$maskData = floatNDArray;
                    this.$alpha = d;
                    super(2, continuation);
                }

                @NotNull
                public final Continuation<Unit> create(@Nullable Object value, @NotNull Continuation<?> completion) {
                    Intrinsics.checkParameterIsNotNull(completion, (String)"completion");
                    Function2<CoroutineScope, Continuation<? super Unit>, Object> function2 = new /* invalid duplicate definition of identical inner class */;
                    CoroutineScope coroutineScope = function2.p$ = (CoroutineScope)value;
                    return function2;
                }

                public final Object invoke(Object object, Object object2) {
                    return (this.create(object, (Continuation)object2)).invokeSuspend(Unit.INSTANCE);
                }
            }));
            // Softmax with axis argument -1; presumably normalizes over the last (attended-sequence) dimension.
            return (NDArray)Softmax.Companion.softmax$default(Softmax.Companion, (NDArray)scores, -1, null, 4, null);
        }

        /**
         * Applies normalized attention scores to the value state for every (batch, head) pair.
         *
         * <p>For each batch index and head (one coroutine each):
         * <ul>
         *   <li>the current value chunk is appended after the past value state (if any) into
         *       {@code present[1, batch, head]};</li>
         *   <li>{@code scores[batch, head]} is multiplied ({@code dot}) by that concatenated value state
         *       into {@code output[batch, head]}.</li>
         * </ul>
         * The output, allocated as [batchSize, numHeads, seqLen, headSize], is then transposed with
         * permutation [0, 2, 1, 3] and reshaped to [batchSize, seqLen, hiddenSize].
         *
         * <p>Side effect: writes the concatenated value state into slot 1 of {@code present}.
         *
         * @param scores normalized scores, sliced per (batch, head) with {@code view}
         * @param values current value chunk, sliced per (batch, head) with {@code view}
         * @param batchSize number of batch entries processed
         * @param seqLen current sequence length
         * @param pastSeqLen NOTE(review): not referenced in this method
         * @param numHeads2 number of attention heads
         * @param hiddenSize total hidden size; {@code headSize = hiddenSize / numHeads2} uses integer division
         *                   (presumably divisible — confirm)
         * @param past optional cached state; slot 1 (values) is read via {@code view(1, batch, head)}
         * @param present destination state; slot 1 receives past values followed by the current values
         * @return pair of (output reshaped to [batchSize, seqLen, hiddenSize], {@code present})
         */
        private final Pair<NDArray, NDArray> attentionScore(NDArray scores, NDArray values, int batchSize, int seqLen, int pastSeqLen, int numHeads2, int hiddenSize, NDArray past, MutableNDArray present) {
            int headSize = hiddenSize / numHeads2;
            // Per-head outputs: [batchSize, numHeads, seqLen, headSize], same data type as scores.
            MutableNDArray output = ArrayFactoriesKt.allocateNDArray((DataType)scores.getType(), (Strides)new Strides(new int[]{batchSize, numHeads2, seqLen, headSize}));
            // Decompiled `runBlocking(Dispatchers.Default) { for (b) for (h) launch { ... } }`.
            // Children are launched in the runBlocking scope, so this call returns only after all of them complete.
            BuildersKt.runBlocking((CoroutineContext)((CoroutineContext)Dispatchers.getDefault()), (Function2)((Function2)new Function2<CoroutineScope, Continuation<? super Unit>, Object>(batchSize, numHeads2, scores, output, values, past, present, null){
                private CoroutineScope p$;
                int label;
                final /* synthetic */ int $batchSize;
                final /* synthetic */ int $numHeads;
                final /* synthetic */ NDArray $scores;
                final /* synthetic */ MutableNDArray $output;
                final /* synthetic */ NDArray $values;
                final /* synthetic */ NDArray $past;
                final /* synthetic */ MutableNDArray $present;

                /*
                 * WARNING - void declaration
                 */
                @Nullable
                public final Object invokeSuspend(@NotNull Object $result) {
                    Object object = IntrinsicsKt.getCOROUTINE_SUSPENDED();
                    switch (this.label) {
                        case 0: {
                            ResultKt.throwOnFailure((Object)$result);
                            CoroutineScope $this$runBlocking = this.p$;
                            int n = 0;
                            int n2 = this.$batchSize;
                            // Outer loop over batch indices (decompiler renders the counter as `void batchNum`).
                            while (n < n2) {
                                void batchNum;
                                int n3 = 0;
                                int n4 = this.$numHeads;
                                // Inner loop over heads; each (batch, head) pair gets its own coroutine.
                                while (n3 < n4) {
                                    void numHead;
                                    BuildersKt.launch$default((CoroutineScope)$this$runBlocking, null, null, (Function2)((Function2)new Function2<CoroutineScope, Continuation<? super Unit>, Object>(this, (int)batchNum, (int)numHead, null){
                                        private CoroutineScope p$;
                                        int label;
                                        final /* synthetic */ attentionScore.1 this$0;
                                        final /* synthetic */ int $batchNum;
                                        final /* synthetic */ int $numHead;

                                        @Nullable
                                        public final Object invokeSuspend(@NotNull Object $result) {
                                            Object object = IntrinsicsKt.getCOROUTINE_SUSPENDED();
                                            switch (this.label) {
                                                case 0: {
                                                    ResultKt.throwOnFailure((Object)$result);
                                                    CoroutineScope $this$launch = this.p$;
                                                    // Per-(batch, head) slices; each coroutine writes only to its own slices of output/present.
                                                    NDArray tempScores = this.this$0.$scores.view(new int[]{this.$batchNum, this.$numHead});
                                                    MutableNDArray tempOutput = this.this$0.$output.viewMutable(new int[]{this.$batchNum, this.$numHead});
                                                    NDArray tempValues = this.this$0.$values.view(new int[]{this.$batchNum, this.$numHead});
                                                    NDArray nDArray = this.this$0.$past;
                                                    // Slot 1 of past/present holds the value state.
                                                    NDArray tempPast = nDArray != null ? nDArray.view(new int[]{1, this.$batchNum, this.$numHead}) : null;
                                                    MutableNDArray tempPresent = this.this$0.$present.viewMutable(new int[]{1, this.$batchNum, this.$numHead});
                                                    // present[1, b, h] = past values (if any) followed by the current value chunk.
                                                    io.kinference.operators.layer.attention.Attention$Companion.access$concatStateChunk(Attention.Companion, tempPast, tempValues, tempPresent);
                                                    // Null checks below are Kotlin cast guards emitted by the compiler.
                                                    NDArray nDArray2 = tempScores;
                                                    if (nDArray2 == null) {
                                                        throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
                                                    }
                                                    MutableNDArray mutableNDArray = tempPresent;
                                                    if (mutableNDArray == null) {
                                                        throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
                                                    }
                                                    MutableNDArray mutableNDArray2 = tempOutput;
                                                    if (mutableNDArray2 == null) {
                                                        throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.MutableNumberNDArray");
                                                    }
                                                    // output[b, h] <- scores[b, h] dot (concatenated values).
                                                    ((NumberNDArray)nDArray2).dot((NumberNDArray)mutableNDArray, (MutableNumberNDArray)mutableNDArray2);
                                                    return Unit.INSTANCE;
                                                }
                                            }
                                            throw new IllegalStateException("call to 'resume' before 'invoke' with coroutine");
                                        }
                                        {
                                            this.this$0 = var1_1;
                                            this.$batchNum = n;
                                            this.$numHead = n2;
                                            super(2, continuation);
                                        }

                                        @NotNull
                                        public final Continuation<Unit> create(@Nullable Object value, @NotNull Continuation<?> completion) {
                                            Intrinsics.checkParameterIsNotNull(completion, (String)"completion");
                                            Function2<CoroutineScope, Continuation<? super Unit>, Object> function2 = new /* invalid duplicate definition of identical inner class */;
                                            CoroutineScope coroutineScope = function2.p$ = (CoroutineScope)value;
                                            return function2;
                                        }

                                        public final Object invoke(Object object, Object object2) {
                                            return (this.create(object, (Continuation)object2)).invokeSuspend(Unit.INSTANCE);
                                        }
                                    }), (int)3, null);
                                    ++numHead;
                                }
                                ++batchNum;
                            }
                            return Unit.INSTANCE;
                        }
                    }
                    throw new IllegalStateException("call to 'resume' before 'invoke' with coroutine");
                }
                {
                    this.$batchSize = n;
                    this.$numHeads = n2;
                    this.$scores = nDArray;
                    this.$output = mutableNDArray;
                    this.$values = nDArray2;
                    this.$past = nDArray3;
                    this.$present = mutableNDArray2;
                    super(2, continuation);
                }

                @NotNull
                public final Continuation<Unit> create(@Nullable Object value, @NotNull Continuation<?> completion) {
                    Intrinsics.checkParameterIsNotNull(completion, (String)"completion");
                    Function2<CoroutineScope, Continuation<? super Unit>, Object> function2 = new /* invalid duplicate definition of identical inner class */;
                    CoroutineScope coroutineScope = function2.p$ = (CoroutineScope)value;
                    return function2;
                }

                public final Object invoke(Object object, Object object2) {
                    return (this.create(object, (Continuation)object2)).invokeSuspend(Unit.INSTANCE);
                }
            }));
            // Reorder axes [batch, head, seq, headSize] -> [batch, seq, head, headSize].
            // NOTE(review): the return value of transpose is discarded; this assumes it permutes `output` in place — confirm.
            output.transpose(new int[]{0, 2, 1, 3});
            // Merge the head and headSize axes into hiddenSize.
            return TuplesKt.to((Object)output.reshapeView(new int[]{batchSize, seqLen, hiddenSize}), (Object)present);
        }

        /**
         * Computes attention for one invocation and builds the "present" state buffer while doing so.
         *
         * <p>The per-head width is {@code hiddenSize / numHeads2}. This is Java integer division, so any
         * remainder is dropped and a zero {@code numHeads2} raises {@code ArithmeticException}.
         * NOTE(review): callers are presumably expected to pass a {@code hiddenSize} divisible by
         * {@code numHeads2}; confirm.
         *
         * <p>When {@code past} is non-null, its dimension 3 is read as the count of previously cached
         * positions. The present buffer's dimension 3 is then {@code seqLen} plus that count.
         *
         * @param unidir2    forwarded unchanged to {@code normalizedScores}
         * @param q          query array; its data type is also used to allocate the present buffer
         * @param k          key array, forwarded to {@code normalizedScores}
         * @param v          value array, forwarded to {@code attentionScore}
         * @param mask       optional mask (may be {@code null}), forwarded to {@code normalizedScores}
         * @param past       optional cached state (may be {@code null})
         * @param batchSize  size of dimension 1 of the present buffer
         * @param seqLen     current sequence length, excluding any cached positions
         * @param numHeads2  number of heads; size of dimension 2 of the present buffer
         * @param hiddenSize total hidden width that is split across the heads
         * @return the pair produced by {@code attentionScore}. Presumably this is (attention output,
         *         present state); TODO confirm.
         */
        @NotNull
        public final Pair<NDArray, NDArray> getScores$inference(boolean unidir2, @NotNull NDArray q, @NotNull NDArray k, @NotNull NDArray v, @Nullable IntNDArray mask, @Nullable NDArray past, int batchSize, int seqLen, int numHeads2, int hiddenSize) {
            Intrinsics.checkParameterIsNotNull((Object)q, (String)"q");
            Intrinsics.checkParameterIsNotNull((Object)k, (String)"k");
            Intrinsics.checkParameterIsNotNull((Object)v, (String)"v");
            int perHead = hiddenSize / numHeads2;
            // Without cached state there is nothing to prepend, so the extra length is zero.
            int cachedLen = past != null ? past.getShape()[3] : 0;
            // Leading 2: presumably keys and values stacked together; confirm against concatStateChunk.
            DataType elementType = (DataType)q.getType();
            Strides presentStrides = new Strides(new int[]{2, batchSize, numHeads2, seqLen + cachedLen, perHead});
            MutableNDArray presentState = ArrayFactoriesKt.allocateNDArray(elementType, presentStrides);
            return this.attentionScore(
                    this.normalizedScores(unidir2, q, k, mask, batchSize, seqLen, cachedLen, perHead, numHeads2, past, presentState),
                    v, batchSize, seqLen, cachedLen, numHeads2, hiddenSize, past, presentState);
        }

        /*
         * Private no-arg constructor with an empty body.
         * The synthetic Companion(DefaultConstructorMarker) overload in this class delegates here.
         */
        private Companion() {
        }

        /*
         * Compiler-generated (synthetic) static accessor. It forwards its arguments unchanged to
         * $this.concatStateChunk(past, chunk, present) and returns nothing.
         * NOTE(review): accessors like this are usually emitted so that nested/lambda classes can
         * reach a private member. Presumably concatStateChunk is private; confirm. Do not hand-edit.
         */
        public static final /* synthetic */ void access$concatStateChunk(Companion $this, NDArray past, NDArray chunk, MutableNDArray present) {
            $this.concatStateChunk(past, chunk, present);
        }

        /*
         * Compiler-generated (synthetic) constructor overload. The DefaultConstructorMarker argument
         * is ignored; construction delegates to the private no-arg constructor.
         */
        public /* synthetic */ Companion(DefaultConstructorMarker $constructor_marker) {
            this();
        }
    }
}

