/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.operators.math;

import io.kinference.attributes.Attribute;
import io.kinference.data.tensors.Tensor;
import io.kinference.data.tensors.TensorExtensionsKt;
import io.kinference.graph.Context;
import io.kinference.ndarray.Strides;
import io.kinference.ndarray.arrays.MutableNDArray;
import io.kinference.ndarray.arrays.MutableNumberNDArray;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.arrays.NumberNDArray;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import io.kinference.ndarray.extensions.BroadcastingKt;
import io.kinference.ndarray.extensions.MatrixKt;
import io.kinference.onnx.AttributeProto;
import io.kinference.onnx.TensorProto;
import io.kinference.operators.AttributeInfo;
import io.kinference.operators.IOInfo;
import io.kinference.operators.Operator;
import io.kinference.operators.OperatorInfo;
import io.kinference.operators.math.Gemm;
import io.kinference.primitives.types.DataType;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.CollectionsKt;
import kotlin.collections.SetsKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.PropertyReference1;
import kotlin.jvm.internal.PropertyReference1Impl;
import kotlin.jvm.internal.Reflection;
import kotlin.reflect.KDeclarationContainer;
import kotlin.reflect.KProperty;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000@\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010$\n\u0002\u0010\u000e\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010 \n\u0002\b\u0003\n\u0002\u0010\u0006\n\u0002\b\b\n\u0002\u0010\u000b\n\u0002\b\b\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018\u0000  2\u000e\u0012\u0004\u0012\u00020\u0002\u0012\u0004\u0012\u00020\u00020\u0001:\u0001 B;\u0012\u0018\u0010\u0003\u001a\u0014\u0012\u0004\u0012\u00020\u0005\u0012\n\u0012\b\u0012\u0004\u0012\u00020\u00070\u00060\u0004\u0012\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00050\t\u0012\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\u00050\t\u00a2\u0006\u0002\u0010\u000bJ(\u0010\u001d\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\u00020\t2\u0006\u0010\u001e\u001a\u00020\u001f2\u000e\u0010\b\u001a\n\u0012\u0006\u0012\u0004\u0018\u00010\u00020\tH\u0016R\u001b\u0010\f\u001a\u00020\r8BX\u0082\u0084\u0002\u00a2\u0006\f\n\u0004\b\u0010\u0010\u0011\u001a\u0004\b\u000e\u0010\u000fR\u001b\u0010\u0012\u001a\u00020\r8BX\u0082\u0084\u0002\u00a2\u0006\f\n\u0004\b\u0014\u0010\u0011\u001a\u0004\b\u0013\u0010\u000fR\u001b\u0010\u0015\u001a\u00020\u00168BX\u0082\u0084\u0002\u00a2\u0006\f\n\u0004\b\u0019\u0010\u0011\u001a\u0004\b\u0017\u0010\u0018R\u001b\u0010\u001a\u001a\u00020\u00168BX\u0082\u0084\u0002\u00a2\u0006\f\n\u0004\b\u001c\u0010\u0011\u001a\u0004\b\u001b\u0010\u0018\u00a8\u0006!"}, d2={"Lio/kinference/operators/math/Gemm;", "Lio/kinference/operators/Operator;", "Lio/kinference/data/tensors/Tensor;", "attributes", "", "", "Lio/kinference/attributes/Attribute;", "", "inputs", "", "outputs", "(Ljava/util/Map;Ljava/util/List;Ljava/util/List;)V", "alpha", "", "getAlpha", "()D", "alpha$delegate", "Lio/kinference/operators/Operator$AttributeValueDelegate;", "beta", "getBeta", "beta$delegate", "transA", "", "getTransA", "()Z", "transA$delegate", "transB", "getTransB", "transB$delegate", "apply", "context", "Lio/kinference/graph/Context;", 
"Companion", "inference"})
public final class Gemm
extends Operator<Tensor, Tensor> {
    // Operator schema shared by every Gemm instance; all of these are assigned in the static initializer.
    private static final Set<TensorProto.DataType> TYPE_CONSTRAINTS;
    private static final List<AttributeInfo> ATTRIBUTES_INFO;
    private static final List<IOInfo> INPUTS_INFO;
    private static final List<IOInfo> OUTPUTS_INFO;
    private static final OperatorInfo INFO;
    public static final Companion Companion;

    // Reflective handles for the delegated properties; slot 0 = alpha, 1 = beta, 2 = transA, 3 = transB.
    static final /* synthetic */ KProperty[] $$delegatedProperties;

    // Per-instance attribute delegates, created in the constructor and read through the private getters.
    private final Operator.AttributeValueDelegate alpha$delegate;
    private final Operator.AttributeValueDelegate beta$delegate;
    private final Operator.AttributeValueDelegate transA$delegate;
    private final Operator.AttributeValueDelegate transB$delegate;

    /*
     * Builds the static operator schema (type constraints, attribute defaults, I/O descriptors)
     * and the reflective property handles used by the attribute getters.
     */
    static {
        // Reflective handles for the delegated properties; indices 0..3 are read by getAlpha,
        // getBeta, getTransA and getTransB respectively.
        $$delegatedProperties = new KProperty[]{(KProperty)Reflection.property1((PropertyReference1)new PropertyReference1Impl((KDeclarationContainer)Reflection.getOrCreateKotlinClass(Gemm.class), "alpha", "getAlpha()D")), (KProperty)Reflection.property1((PropertyReference1)new PropertyReference1Impl((KDeclarationContainer)Reflection.getOrCreateKotlinClass(Gemm.class), "beta", "getBeta()D")), (KProperty)Reflection.property1((PropertyReference1)new PropertyReference1Impl((KDeclarationContainer)Reflection.getOrCreateKotlinClass(Gemm.class), "transA", "getTransA()Z")), (KProperty)Reflection.property1((PropertyReference1)new PropertyReference1Impl((KDeclarationContainer)Reflection.getOrCreateKotlinClass(Gemm.class), "transB", "getTransB()Z"))};
        Companion = new Companion(null);
        // Element types accepted for every input (A, B, C) and the output Y.
        TYPE_CONSTRAINTS = SetsKt.setOf((Object[])new TensorProto.DataType[]{TensorProto.DataType.FLOAT16, TensorProto.DataType.FLOAT, TensorProto.DataType.DOUBLE, TensorProto.DataType.UINT32, TensorProto.DataType.UINT64, TensorProto.DataType.INT32, TensorProto.DataType.INT64, TensorProto.DataType.BFLOAT16});
        // alpha/beta are FLOAT attributes defaulting to 1.0; transA/transB are INT attributes defaulting to 0.
        // NOTE(review): the `false` argument presumably marks each attribute as not required — confirm in AttributeInfo.
        ATTRIBUTES_INFO = CollectionsKt.listOf((Object[])new AttributeInfo[]{new AttributeInfo("alpha", SetsKt.setOf((Object)((Object)AttributeProto.AttributeType.FLOAT)), false, 1.0), new AttributeInfo("beta", SetsKt.setOf((Object)((Object)AttributeProto.AttributeType.FLOAT)), false, 1.0), new AttributeInfo("transA", SetsKt.setOf((Object)((Object)AttributeProto.AttributeType.INT)), false, 0), new AttributeInfo("transB", SetsKt.setOf((Object)((Object)AttributeProto.AttributeType.INT)), false, 0)});
        // Inputs A (0), B (1) and C (2). C is the only one flagged `true`, presumably "optional";
        // apply() reads it with getOrNull. The trailing `240, null` appears to be the Kotlin
        // default-argument bitmask plus its marker, i.e. the remaining IOInfo parameters use defaults.
        INPUTS_INFO = CollectionsKt.listOf((Object[])new IOInfo[]{new IOInfo(0, TYPE_CONSTRAINTS, "A", false, null, false, null, 0, 240, null), new IOInfo(1, TYPE_CONSTRAINTS, "B", false, null, false, null, 0, 240, null), new IOInfo(2, TYPE_CONSTRAINTS, "C", true, null, false, null, 0, 240, null)});
        // Single output Y.
        OUTPUTS_INFO = CollectionsKt.listOf((Object)new IOInfo(0, TYPE_CONSTRAINTS, "Y", false, null, false, null, 0, 240, null));
        INFO = new OperatorInfo("Gemm", (Collection<AttributeInfo>)ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO);
    }

    /**
     * Returns the {@code alpha} attribute, which is declared as FLOAT with default 1.0 in the schema.
     *
     * @return the scalar multiplier for the matrix product, widened to {@code double}
     */
    private final double getAlpha() {
        Object raw = this.alpha$delegate.getValue(this, $$delegatedProperties[0]);
        Number value = (Number) raw;
        return value.doubleValue();
    }

    /**
     * Returns the {@code beta} attribute, which is declared as FLOAT with default 1.0 in the schema.
     *
     * @return the scalar multiplier applied to the C operand, widened to {@code double}
     */
    private final double getBeta() {
        Object raw = this.beta$delegate.getValue(this, $$delegatedProperties[1]);
        Number value = (Number) raw;
        return value.doubleValue();
    }

    /**
     * Returns whether A should be used transposed.
     *
     * <p>The schema declares {@code transA} as an INT attribute (default 0). The delegate yields a
     * {@code Boolean}; the int-to-boolean conversion presumably happens in the delegate's mapper — TODO confirm.
     *
     * @return {@code true} if A is transposed before multiplication
     */
    private final boolean getTransA() {
        Object raw = this.transA$delegate.getValue(this, $$delegatedProperties[2]);
        Boolean flag = (Boolean) raw;
        return flag.booleanValue();
    }

    /**
     * Returns whether B should be used transposed.
     *
     * <p>The schema declares {@code transB} as an INT attribute (default 0). The delegate yields a
     * {@code Boolean}; the int-to-boolean conversion presumably happens in the delegate's mapper — TODO confirm.
     *
     * @return {@code true} if B is transposed before multiplication
     */
    private final boolean getTransB() {
        Object raw = this.transB$delegate.getValue(this, $$delegatedProperties[3]);
        Boolean flag = (Boolean) raw;
        return flag.booleanValue();
    }

    /**
     * Evaluates the ONNX Gemm operator on {@code inputs} = [A, B, optional C].
     *
     * <p>Derives the output dimensions {@code m x n} and the shared inner dimension {@code k}
     * from A/B and the transpose flags. It then builds the destination from C (or a fresh array
     * when C is absent) and hands everything to {@code MatrixKt.gemm} together with alpha and beta.
     *
     * @param context execution context (only null-checked here)
     * @param inputs  input tensors; indices 0 and 1 must be present, index 2 may be missing
     * @return a single-element list holding the result tensor Y
     */
    @Override
    @NotNull
    public List<Tensor> apply(@NotNull Context context, @NotNull List<Tensor> inputs) {
        Intrinsics.checkParameterIsNotNull((Object)context, (String)"context");
        Intrinsics.checkParameterIsNotNull(inputs, (String)"inputs");
        NumberNDArray a = Gemm.numberInputAt(inputs, 0);
        NumberNDArray b = Gemm.numberInputAt(inputs, 1);

        int[] shapeA = a.getShape();
        int[] shapeB = b.getShape();
        // Rows of op(A), columns of op(B), and the contracted dimension of op(A).
        int m = this.getTransA() ? shapeA[1] : shapeA[0];
        int n = this.getTransB() ? shapeB[0] : shapeB[1];
        int k = this.getTransA() ? shapeA[0] : shapeA[1];

        Tensor cTensor = (Tensor)CollectionsKt.getOrNull(inputs, (int)2);
        NDArray cData;
        if (cTensor == null) {
            cData = null;
        } else {
            cData = cTensor.getData();
        }
        MutableNDArray c = Gemm.Companion.getDest(cData, a.getType(), new int[]{m, n});

        // 1792 appears to be the Kotlin default-argument mask: the three zero offsets are defaulted.
        MatrixKt.gemm$default(m, n, k, this.getAlpha(), a, b, this.getBeta(), c, 0, 0, 0, this.getTransA(), this.getTransB(), 1792, null);
        return CollectionsKt.listOf((Object)TensorExtensionsKt.asTensor$default((NDArray)c, null, 1, null));
    }

    /**
     * Fetches the data of the input tensor at {@code index} as a {@link NumberNDArray}.
     * It fails the same way the inline Kotlin null-checks did.
     */
    private static NumberNDArray numberInputAt(List<Tensor> inputs, int index) {
        Tensor tensor = inputs.get(index);
        if (tensor == null) {
            Intrinsics.throwNpe();
        }
        NDArray data = tensor.getData();
        if (data == null) {
            throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
        }
        return (NumberNDArray)data;
    }

    /**
     * Creates a Gemm node bound to the given attribute map and input/output names, and wires up
     * the attribute delegates used by the private getters.
     *
     * @param attributes attribute values keyed by attribute name
     * @param inputs     names of the node's inputs
     * @param outputs    names of the node's outputs
     */
    public Gemm(@NotNull Map<String, Attribute<Object>> attributes, @NotNull List<String> inputs, @NotNull List<String> outputs) {
        // NOTE(review): CFR artefact — the Kotlin compiler emits these null-checks before the
        // superclass constructor call. That ordering is not legal Java source, so this file is
        // presumably a decompiled reference rather than compilable code.
        Intrinsics.checkParameterIsNotNull(attributes, (String)"attributes");
        Intrinsics.checkParameterIsNotNull(inputs, (String)"inputs");
        Intrinsics.checkParameterIsNotNull(outputs, (String)"outputs");
        super(INFO, attributes, inputs, outputs);
        // `alpha.2.INSTANCE` etc. presumably name compiler-generated mapper lambdas (e.g. Gemm$alpha$2)
        // that convert the raw attribute value to the property type — TODO confirm in the Kotlin source.
        this.alpha$delegate = Operator.attribute$default(this, null, alpha.2.INSTANCE, 1, null);
        this.beta$delegate = Operator.attribute$default(this, null, beta.2.INSTANCE, 1, null);
        this.transA$delegate = Operator.attribute$default(this, null, transA.2.INSTANCE, 1, null);
        this.transB$delegate = Operator.attribute$default(this, null, transB.2.INSTANCE, 1, null);
    }

    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000F\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\"\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0015\n\u0000\b\u0086\u0003\u0018\u00002\u00020\u0001B\u0007\b\u0002\u00a2\u0006\u0002\u0010\u0002J\"\u0010\u000e\u001a\u00020\u000f2\b\u0010\u0010\u001a\u0004\u0018\u00010\u00112\u0006\u0010\u0012\u001a\u00020\u00132\u0006\u0010\u0014\u001a\u00020\u0015H\u0002R\u0014\u0010\u0003\u001a\b\u0012\u0004\u0012\u00020\u00050\u0004X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010\b\u001a\b\u0012\u0004\u0012\u00020\t0\u0004X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010\n\u001a\b\u0012\u0004\u0012\u00020\t0\u0004X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u0006\u0016"}, d2={"Lio/kinference/operators/math/Gemm$Companion;", "", "()V", "ATTRIBUTES_INFO", "", "Lio/kinference/operators/AttributeInfo;", "INFO", "Lio/kinference/operators/OperatorInfo;", "INPUTS_INFO", "Lio/kinference/operators/IOInfo;", "OUTPUTS_INFO", "TYPE_CONSTRAINTS", "", "Lio/kinference/onnx/TensorProto$DataType;", "getDest", "Lio/kinference/ndarray/arrays/MutableNDArray;", "array", "Lio/kinference/ndarray/arrays/NDArray;", "type", "Lio/kinference/primitives/types/DataType;", "targetShape", "", "inference"})
    public static final class Companion {
        /*
         * WARNING - void declaration
         */
        private final MutableNDArray getDest(NDArray array, DataType type, int[] targetShape) {
            int n;
            int targetBlockSize;
            if (array == null) {
                return ArrayFactoriesKt.allocateNDArray((DataType)type, (Strides)new Strides(targetShape));
            }
            int[] nArray = array.getShape();
            boolean bl = false;
            if (Arrays.equals(nArray, targetShape)) {
                return NDArray.DefaultImpls.toMutable$default((NDArray)array, null, (int)1, null);
            }
            MutableNDArray mutableNDArray = ArrayFactoriesKt.allocateNDArray((DataType)type, (Strides)new Strides(targetShape));
            if (mutableNDArray == null) {
                throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.MutableNumberNDArray");
            }
            MutableNumberNDArray dstArray = (MutableNumberNDArray)mutableNDArray;
            int[] unsqueezedShape = BroadcastingKt.unsqueezeFirst((int[])array.getShape(), (int)targetShape.length);
            if (targetShape[1] != unsqueezedShape[1] && unsqueezedShape[1] == 1) {
                targetBlockSize = targetShape[1];
                n = 0;
                int n2 = unsqueezedShape[0];
                while (n < n2) {
                    void i;
                    void dstOffsetBase = i * targetBlockSize;
                    dstArray.fillByArrayValue(array, (int)i, (int)dstOffsetBase, (int)(dstOffsetBase + targetBlockSize));
                    ++i;
                }
            } else {
                MutableNDArray.DefaultImpls.copyFrom$default((MutableNDArray)dstArray, (int)0, (NDArray)array, (int)0, (int)0, (int)12, null);
            }
            targetBlockSize = 1;
            n = targetShape[0];
            while (targetBlockSize < n) {
                void i;
                dstArray.copyFrom((int)(i * targetShape[1]), (NDArray)dstArray, 0, targetShape[1]);
                ++i;
            }
            return (MutableNDArray)dstArray;
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker $constructor_marker) {
            this();
        }
    }
}

