/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.ndarray.extensions;

import io.kinference.ndarray.Strides;
import io.kinference.ndarray.arrays.IntNDArray;
import io.kinference.ndarray.arrays.LongNDArray;
import io.kinference.ndarray.arrays.MutableNDArray;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.arrays.pointers.IntPointer;
import io.kinference.ndarray.arrays.pointers.LongPointer;
import io.kinference.ndarray.arrays.tiled.IntTiledArray;
import io.kinference.ndarray.arrays.tiled.LongTiledArray;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import io.kinference.ndarray.extensions.GatherKt$WhenMappings;
import io.kinference.ndarray.extensions.NDArrayExtensionsKt;
import io.kinference.primitives.types.DataType;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.IntRange;
import kotlin.ranges.RangesKt;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=2, d1={"\u0000\"\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0015\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\u001a&\u0010\u0000\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00052\u0006\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\t\u001a\u001e\u0010\n\u001a\u00020\u0003*\u00020\u00052\b\b\u0002\u0010\u000b\u001a\u00020\u00032\b\b\u0002\u0010\f\u001a\u00020\u0003\u001a\u001c\u0010\r\u001a\u00020\u0005*\u00020\u00052\u0006\u0010\u0004\u001a\u00020\u00052\b\b\u0002\u0010\u0002\u001a\u00020\u0003\u00a8\u0006\u000e"}, d2={"createGatherDstArray", "Lio/kinference/ndarray/arrays/MutableNDArray;", "axis", "", "indices", "Lio/kinference/ndarray/arrays/NDArray;", "shape", "", "type", "Lio/kinference/primitives/types/DataType;", "computeBlockSize", "fromDim", "toDim", "gather", "ndarray"})
public final class GatherKt {
    /*
     * WARNING - void declaration
     */
    public static final int computeBlockSize(@NotNull NDArray $this$computeBlockSize, int fromDim, int toDim) {
        void $this$fold$iv;
        Intrinsics.checkParameterIsNotNull((Object)$this$computeBlockSize, (String)"$this$computeBlockSize");
        int[] nArray = ArraysKt.sliceArray((int[])$this$computeBlockSize.getShape(), (IntRange)RangesKt.until((int)fromDim, (int)toDim));
        int initial$iv = 1;
        boolean $i$f$fold = false;
        int accumulator$iv = initial$iv;
        void var7_7 = $this$fold$iv;
        int n = ((void)var7_7).length;
        for (int i = 0; i < n; ++i) {
            void p2;
            void element$iv;
            void var11_11 = element$iv = var7_7[i];
            int p1 = accumulator$iv;
            boolean bl = false;
            accumulator$iv = p1 * p2;
        }
        return accumulator$iv;
    }

    public static /* synthetic */ int computeBlockSize$default(NDArray nDArray, int n, int n2, int n3, Object object) {
        if ((n3 & 1) != 0) {
            n = 0;
        }
        if ((n3 & 2) != 0) {
            n2 = nDArray.getShape().length;
        }
        return GatherKt.computeBlockSize(nDArray, n, n2);
    }

    @NotNull
    public static final MutableNDArray createGatherDstArray(int axis, @NotNull NDArray indices, @NotNull int[] shape, @NotNull DataType type) {
        Intrinsics.checkParameterIsNotNull((Object)indices, (String)"indices");
        Intrinsics.checkParameterIsNotNull((Object)shape, (String)"shape");
        Intrinsics.checkParameterIsNotNull((Object)type, (String)"type");
        int[] newShape = new int[shape.length + indices.getRank() - 1];
        ArraysKt.copyInto((int[])shape, (int[])newShape, (int)0, (int)0, (int)axis);
        ArraysKt.copyInto$default((int[])indices.getShape(), (int[])newShape, (int)axis, (int)0, (int)0, (int)12, null);
        ArraysKt.copyInto$default((int[])shape, (int[])newShape, (int)(axis + indices.getRank()), (int)(axis + 1), (int)0, (int)8, null);
        Strides newStrides = new Strides(newShape);
        return ArrayFactoriesKt.allocateNDArray(type, newStrides);
    }

    /*
     * WARNING - void declaration
     */
    @NotNull
    public static final NDArray gather(@NotNull NDArray $this$gather, @NotNull NDArray indices, int axis) {
        Intrinsics.checkParameterIsNotNull((Object)$this$gather, (String)"$this$gather");
        Intrinsics.checkParameterIsNotNull((Object)indices, (String)"indices");
        int actualAxis = NDArrayExtensionsKt.indexAxis($this$gather, axis);
        MutableNDArray dst = GatherKt.createGatherDstArray(actualAxis, indices, $this$gather.getShape(), $this$gather.getType());
        int block = GatherKt.computeBlockSize$default($this$gather, actualAxis + 1, 0, 2, null);
        int dataBatch = GatherKt.computeBlockSize$default($this$gather, actualAxis, 0, 2, null);
        int indicesSize = indices.getStrides().getLinearSize();
        int gatheredBatch = indicesSize * block;
        int numBlocks = GatherKt.computeBlockSize$default($this$gather, 0, actualAxis, 1, null);
        switch (GatherKt$WhenMappings.$EnumSwitchMapping$0[indices.getType().ordinal()]) {
            case 1: {
                LongNDArray cfr_ignored_0 = (LongNDArray)indices;
                LongPointer pointer = LongTiledArray.pointer$default(((LongNDArray)indices).getArray(), 0, 1, null);
                int n = 0;
                int n2 = numBlocks;
                while (n < n2) {
                    void numBatch;
                    int offset$iv;
                    long[] block$iv;
                    int index = 0;
                    LongPointer $this$forEach$iv = pointer;
                    boolean $i$f$forEach = false;
                    for (int end$iv = indicesSize; end$iv > 0; end$iv -= block$iv.length - offset$iv) {
                        int n3;
                        block$iv = $this$forEach$iv.getCurrentBlock();
                        offset$iv = $this$forEach$iv.getIndexInBlock();
                        LongPointer this_$iv$iv22 = $this$forEach$iv;
                        int $i$f$blockIncrement = 0;
                        if (this_$iv$iv22.getBlockNum() < this_$iv$iv22.getArray().getBlocksNum() - 1) {
                            LongPointer longPointer = this_$iv$iv22;
                            n3 = longPointer.getBlockNum();
                            longPointer.setBlockNum(n3 + 1);
                            this_$iv$iv22.setIndexInBlock(0);
                            this_$iv$iv22.setCurrentBlock(this_$iv$iv22.getArray().getBlocks()[this_$iv$iv22.getBlockNum()]);
                        } else {
                            this_$iv$iv22.setIndexInBlock(this_$iv$iv22.getArray().getBlockSize());
                        }
                        int this_$iv$iv22 = offset$iv;
                        n3 = block$iv.length;
                        int n4 = offset$iv + end$iv;
                        boolean bl = false;
                        $i$f$blockIncrement = Math.min(n3, n4);
                        while (this_$iv$iv22 < $i$f$blockIncrement) {
                            void index$iv;
                            long it = block$iv[index$iv];
                            boolean bl2 = false;
                            int idx = (int)(it < 0L ? it + (long)$this$gather.getShape()[actualAxis] : it);
                            void srcOffset = numBatch * dataBatch + idx * block;
                            int n5 = index;
                            index = n5 + 1;
                            void dstOffset = numBatch * gatheredBatch + n5 * block;
                            dst.copyFrom((int)dstOffset, $this$gather, (int)srcOffset, (int)(srcOffset + block));
                            ++index$iv;
                        }
                    }
                    pointer.setLinearIndex(0);
                    ++numBatch;
                }
                break;
            }
            case 2: {
                IntNDArray cfr_ignored_1 = (IntNDArray)indices;
                IntPointer pointer = IntTiledArray.pointer$default(((IntNDArray)indices).getArray(), 0, 1, null);
                int n = numBlocks;
                for (int numBatch = 0; numBatch < n; ++numBatch) {
                    int offset$iv;
                    int[] block$iv;
                    int index = 0;
                    IntPointer $this$forEach$iv = pointer;
                    boolean $i$f$forEach = false;
                    for (int end$iv = indicesSize; end$iv > 0; end$iv -= block$iv.length - offset$iv) {
                        int n6;
                        block$iv = $this$forEach$iv.getCurrentBlock();
                        offset$iv = $this$forEach$iv.getIndexInBlock();
                        IntPointer this_$iv$iv32 = $this$forEach$iv;
                        boolean $i$f$blockIncrement = false;
                        if (this_$iv$iv32.getBlockNum() < this_$iv$iv32.getArray().getBlocksNum() - 1) {
                            IntPointer intPointer = this_$iv$iv32;
                            n6 = intPointer.getBlockNum();
                            intPointer.setBlockNum(n6 + 1);
                            this_$iv$iv32.setIndexInBlock(0);
                            this_$iv$iv32.setCurrentBlock(this_$iv$iv32.getArray().getBlocks()[this_$iv$iv32.getBlockNum()]);
                        } else {
                            this_$iv$iv32.setIndexInBlock(this_$iv$iv32.getArray().getBlockSize());
                        }
                        int this_$iv$iv32 = offset$iv;
                        n6 = block$iv.length;
                        int n7 = offset$iv + end$iv;
                        boolean bl = false;
                        int n8 = Math.min(n6, n7);
                        while (this_$iv$iv32 < n8) {
                            void index$iv;
                            int it = block$iv[index$iv];
                            boolean bl3 = false;
                            int idx = it < 0 ? it + $this$gather.getShape()[actualAxis] : it;
                            int srcOffset = numBatch * dataBatch + idx * block;
                            int n9 = index;
                            index = n9 + 1;
                            int dstOffset = numBatch * gatheredBatch + n9 * block;
                            dst.copyFrom(dstOffset, $this$gather, srcOffset, srcOffset + block);
                            ++index$iv;
                        }
                    }
                    pointer.setLinearIndex(0);
                }
                break;
            }
            default: {
                throw (Throwable)new IllegalStateException("Indices array must have Long or Int type");
            }
        }
        return dst;
    }

    public static /* synthetic */ NDArray gather$default(NDArray nDArray, NDArray nDArray2, int n, int n2, Object object) {
        if ((n2 & 2) != 0) {
            n = 0;
        }
        return GatherKt.gather(nDArray, nDArray2, n);
    }
}

