/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.ndarray.extensions;

import io.kinference.ndarray.Strides;
import io.kinference.ndarray.UtilsKt;
import io.kinference.ndarray.arrays.MutableNDArray;
import io.kinference.ndarray.arrays.MutableNumberNDArray;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.arrays.NumberNDArray;
import io.kinference.ndarray.extensions.BroadcastingKt;
import io.kinference.ndarray.extensions.MatrixKt;
import io.kinference.ndarray.extensions.NDArrayExtensionsKt;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.Intrinsics;
import kotlin.reflect.KDeclarationContainer;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=2, d1={"\u0000@\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0010\u0006\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\u001ax\u0010\u0000\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00032\u0006\u0010\u0005\u001a\u00020\u00032\u0006\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\t2\u0006\u0010\u000b\u001a\u00020\u00072\u0006\u0010\f\u001a\u00020\u00012\b\b\u0002\u0010\r\u001a\u00020\u00032\b\b\u0002\u0010\u000e\u001a\u00020\u00032\b\b\u0002\u0010\u000f\u001a\u00020\u00032\b\b\u0002\u0010\u0010\u001a\u00020\u00112\b\b\u0002\u0010\u0012\u001a\u00020\u0011\u001a\u0014\u0010\u0013\u001a\u00020\u0014*\u00020\t2\u0006\u0010\u0015\u001a\u00020\tH\u0002\u001a\u0015\u0010\u0016\u001a\u00020\u0017*\u00020\t2\u0006\u0010\u0015\u001a\u00020\tH\u0086\u0004\u001aA\u0010\u0016\u001a\u00020\u0017*\u00020\t2\u0006\u0010\u0015\u001a\u00020\t2\u0006\u0010\u0018\u001a\u00020\u00172#\u0010\u0019\u001a\u001f\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\t\u0012\u0004\u0012\u00020\u0017\u0012\u0004\u0012\u00020\u00170\u001a\u00a2\u0006\u0002\b\u001bH\u0002\u00a8\u0006\u001c"}, d2={"gemm", "Lio/kinference/ndarray/arrays/MutableNDArray;", "m", "", "n", "k", "alpha", "", "a", "Lio/kinference/ndarray/arrays/NumberNDArray;", "b", "beta", "c", "aOffset", "bOffset", "cOffset", "transposeA", "", "transposeB", "getOutputStrides", "Lio/kinference/ndarray/Strides;", "other", "matmul", "Lio/kinference/ndarray/arrays/MutableNumberNDArray;", "dest", "dotFunc", "Lkotlin/Function3;", "Lkotlin/ExtensionFunctionType;", "ndarray"})
public final class MatrixKt {
    /*
     * Matrix helpers for NumberNDArray: a thin gemm entry point and matmul with batch broadcasting.
     * This class was decompiled from Kotlin. The synthetic lambda classes referenced below
     * (matmul.1, matmul.3) are compiler-generated and are kept exactly as the decompiler emitted them.
     */

    /**
     * Delegates a general matrix multiply to {@link NumberNDArray#gemm}.
     *
     * <p>The leading dimensions are derived from the transpose flags. lda is m when A is transposed,
     * otherwise k. ldb is k when B is transposed, otherwise n. ldc is always n. NOTE(review): this
     * presumes dense row-major operands and BLAS-style semantics
     * (c = alpha * op(a) * op(b) + beta * c), as defined by NumberNDArray.gemm. Confirm there.
     *
     * @param m          rows of op(a) and of c
     * @param n          columns of op(b) and of c
     * @param k          shared inner dimension
     * @param alpha      scale applied to the product
     * @param a          left operand
     * @param b          right operand
     * @param beta       scale applied to the existing contents of c
     * @param c          destination array, passed to a.gemm and returned from it
     * @param aOffset    element offset into a
     * @param bOffset    element offset into b
     * @param cOffset    element offset into c
     * @param transposeA whether a is used transposed
     * @param transposeB whether b is used transposed
     * @return whatever {@code a.gemm(...)} returns
     */
    @NotNull
    public static final MutableNDArray gemm(int m, int n, int k, double alpha, @NotNull NumberNDArray a, @NotNull NumberNDArray b, double beta, @NotNull MutableNDArray c, int aOffset, int bOffset, int cOffset, boolean transposeA, boolean transposeB) {
        Intrinsics.checkParameterIsNotNull((Object)a, (String)"a");
        Intrinsics.checkParameterIsNotNull((Object)b, (String)"b");
        Intrinsics.checkParameterIsNotNull((Object)c, (String)"c");
        int lda = transposeA ? m : k;
        int ldb = transposeB ? k : n;
        return a.gemm(m, n, k, alpha, lda, b, ldb, beta, c, n, aOffset, bOffset, cOffset, transposeA, transposeB);
    }

    /**
     * Kotlin default-arguments bridge for {@link #gemm}.
     *
     * <p>Bit i of {@code defaultsMask} selects the default for parameter i. Bits 8 through 12 map to
     * aOffset, bOffset and cOffset (default 0) and transposeA and transposeB (default false).
     * {@code marker} is unused, as in every Kotlin $default bridge.
     */
    public static /* synthetic */ MutableNDArray gemm$default(int m, int n, int k, double alpha, NumberNDArray a, NumberNDArray b, double beta, MutableNDArray c, int aOffset, int bOffset, int cOffset, boolean transposeA, boolean transposeB, int defaultsMask, Object marker) {
        if ((defaultsMask & 0x100) != 0) {
            aOffset = 0;
        }
        if ((defaultsMask & 0x200) != 0) {
            bOffset = 0;
        }
        if ((defaultsMask & 0x400) != 0) {
            cOffset = 0;
        }
        if ((defaultsMask & 0x800) != 0) {
            transposeA = false;
        }
        if ((defaultsMask & 0x1000) != 0) {
            transposeB = false;
        }
        return MatrixKt.gemm(m, n, k, alpha, a, b, beta, c, aOffset, bOffset, cOffset, transposeA, transposeB);
    }

    /**
     * Computes the strides of the matmul result.
     *
     * <p>The result shape is broadcastShape(leading axes of the receiver, leading axes of other),
     * followed by [receiver.shape[-2], other.shape[-1]].
     *
     * <p>NOTE(review): for a rank-1 operand the leading-axis slice is {@code copyOfRange(0, -1)},
     * which appears to be rejected. Confirm whether rank-1 inputs can reach this method.
     */
    private static final Strides getOutputStrides(@NotNull NumberNDArray $this$getOutputStrides, NumberNDArray other) {
        // The last two output axes are the receiver's rows and other's columns.
        int[] outputMatrixShape = new int[]{
            $this$getOutputStrides.getShape()[NDArrayExtensionsKt.indexAxis($this$getOutputStrides, -2)],
            other.getShape()[NDArrayExtensionsKt.indexAxis(other, -1)]
        };
        // All axes except the last two are batch axes. They are broadcast against each other.
        int[] leftBatchShape = ArraysKt.copyOfRange((int[])$this$getOutputStrides.getShape(), (int)0, (int)($this$getOutputStrides.getRank() - 2));
        int[] rightBatchShape = ArraysKt.copyOfRange((int[])other.getShape(), (int)0, (int)(other.getRank() - 2));
        int[] broadcastShape = BroadcastingKt.broadcastShape(leftBatchShape, rightBatchShape);
        int[] outputShape = new int[broadcastShape.length + 2];
        System.arraycopy(broadcastShape, 0, outputShape, 0, broadcastShape.length);
        System.arraycopy(outputMatrixShape, 0, outputShape, broadcastShape.length, outputMatrixShape.length);
        return new Strides(outputShape);
    }

    /**
     * Matrix product of the receiver and {@code other}, written into a freshly allocated array.
     *
     * <p>The destination is allocated from the receiver with the strides computed by
     * getOutputStrides. The per-matrix product uses the compiler-generated lambda
     * {@code matmul.1} (its body is not visible here).
     *
     * @param $this$matmul left operand (the Kotlin extension receiver)
     * @param other        right operand
     * @return the destination array holding the product
     */
    @NotNull
    public static final MutableNumberNDArray matmul(@NotNull NumberNDArray $this$matmul, @NotNull NumberNDArray other) {
        Intrinsics.checkParameterIsNotNull((Object)$this$matmul, (String)"$this$matmul");
        Intrinsics.checkParameterIsNotNull((Object)other, (String)"other");
        Strides outputStrides = MatrixKt.getOutputStrides($this$matmul, other);
        MutableNumberNDArray outputArray = $this$matmul.allocateNDArray(outputStrides);
        return MatrixKt.matmul($this$matmul, other, outputArray, (Function3<? super NumberNDArray, ? super NumberNDArray, ? super MutableNumberNDArray, ? extends MutableNumberNDArray>)((Function3)matmul.1.INSTANCE));
    }

    /**
     * Core matmul. It multiplies matrices with {@code dotFunc} and recurses over batch axes with
     * innerBroadcast.
     *
     * <ul>
     *   <li>If both operands have rank 2 or less, a rank-1 left operand is viewed as a row [1, k]
     *       and a rank-1 right operand as a column [k, 1]. Then {@code dotFunc} is applied once.</li>
     *   <li>Otherwise both operands get leading axes added until they match dest's rank. The
     *       resulting views are then processed recursively. A rank-2 left view goes straight to
     *       {@code dotFunc}; anything else goes through innerBroadcast.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if either operand is a scalar
     */
    private static final MutableNumberNDArray matmul(@NotNull NumberNDArray $this$matmul, NumberNDArray other, MutableNumberNDArray dest, Function3<? super NumberNDArray, ? super NumberNDArray, ? super MutableNumberNDArray, ? extends MutableNumberNDArray> dotFunc) {
        if (NDArrayExtensionsKt.isScalar($this$matmul) || NDArrayExtensionsKt.isScalar(other)) {
            throw (Throwable)new IllegalArgumentException("Matmul operation is not available for scalar tensors");
        }
        // Decompiled local function `matmul(left, right, dest)` from the Kotlin source. Kept verbatim.
        Function3<NDArray, NDArray, MutableNDArray, Unit> $fun$matmul$3 = new Function3<NDArray, NDArray, MutableNDArray, Unit>(dotFunc){
            final /* synthetic */ Function3 $dotFunc;

            public final void invoke(@NotNull NDArray left, @NotNull NDArray right, @NotNull MutableNDArray destination) {
                Intrinsics.checkParameterIsNotNull((Object)left, (String)"left");
                Intrinsics.checkParameterIsNotNull((Object)right, (String)"right");
                Intrinsics.checkParameterIsNotNull((Object)destination, (String)"destination");
                if (left.getRank() == 2) {
                    this.$dotFunc.invoke((Object)((NumberNDArray)left), (Object)((NumberNDArray)right), (Object)((MutableNumberNDArray)destination));
                } else {
                    BroadcastingKt.innerBroadcast(left, right, destination, (Function3<? super NDArray, ? super NDArray, ? super MutableNDArray, Unit>)((Function3)new Function3<NDArray, NDArray, MutableNDArray, Unit>(this){
                        final /* synthetic */ matmul.3 this$0;

                        public final void invoke(@NotNull NDArray p1, @NotNull NDArray p2, @NotNull MutableNDArray p3) {
                            Intrinsics.checkParameterIsNotNull((Object)p1, (String)"p1");
                            Intrinsics.checkParameterIsNotNull((Object)p2, (String)"p2");
                            Intrinsics.checkParameterIsNotNull((Object)p3, (String)"p3");
                            this.this$0.invoke(p1, p2, p3);
                        }

                        public final KDeclarationContainer getOwner() {
                            return null;
                        }

                        public final String getName() {
                            return "matmul";
                        }

                        public final String getSignature() {
                            return "invoke(Lio/kinference/ndarray/arrays/NDArray;Lio/kinference/ndarray/arrays/NDArray;Lio/kinference/ndarray/arrays/MutableNDArray;)V";
                        }
                        {
                            this.this$0 = var1_1;
                            super(3);
                        }
                    }));
                }
            }
            {
                this.$dotFunc = function3;
                super(3);
            }
        };
        if ($this$matmul.getRank() <= 2 && other.getRank() <= 2) {
            // A rank-1 left operand is viewed as a row vector [1, k].
            NDArray leftView = $this$matmul.getRank() == 1
                ? $this$matmul.reshapeView(UtilsKt.concat(1, $this$matmul.getShape()))
                : (NDArray)$this$matmul;
            if (leftView == null) {
                throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
            }
            // A rank-1 right operand is viewed as a column vector [k, 1]. The view must be built
            // from `other`. Previously it was built from the left operand, which multiplied the
            // wrong data.
            NDArray rightView = other.getRank() == 1
                ? other.reshapeView(UtilsKt.concat(other.getShape(), 1))
                : (NDArray)other;
            if (rightView == null) {
                throw new TypeCastException("null cannot be cast to non-null type io.kinference.ndarray.arrays.NumberNDArray");
            }
            return dotFunc.invoke((NumberNDArray)leftView, (NumberNDArray)rightView, dest);
        }
        // Batched case: add leading axes to both operands so their ranks match dest's.
        int[] leftWrapShape = BroadcastingKt.unsqueezeFirst($this$matmul.getShape(), dest.getRank());
        int[] rightWrapShape = BroadcastingKt.unsqueezeFirst(other.getShape(), dest.getRank());
        NDArray leftWrapped = $this$matmul.reshapeView(leftWrapShape);
        NDArray rightWrapped = other.reshapeView(rightWrapShape);
        $fun$matmul$3.invoke(leftWrapped, rightWrapped, (MutableNDArray)dest);
        return dest;
    }
}

