package puck.linalg.kernels;

import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLDevice;
import com.nativelibs4java.opencl.CLProgram;
import java.util.WeakHashMap;
import scala.Predef$;
import scala.collection.JavaConverters$;
import scala.collection.mutable.MapLike;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.RichLong;

/* compiled from: CLMatrixTransposeCopy.scala */
/* loaded from: input_file:puck/linalg/kernels/CLMatrixTransposeCopy$.class */
public final class CLMatrixTransposeCopy$ {
    /**
     * Singleton instance of this Scala companion object. It is declared as {@code null} here;
     * the private constructor stores {@code this} into it when the static initializer runs.
     */
    public static final CLMatrixTransposeCopy$ MODULE$ = null;
    /**
     * Cache of {@link CLMatrixTransposeCopy} instances keyed by the {@link CLContext} they were
     * built for. {@code apply} synchronizes on this map while reading and populating it.
     */
    private final WeakHashMap<CLContext, CLMatrixTransposeCopy> map;

    // Class initializer: instantiates the singleton once. The result of `new` is discarded
    // on purpose; the private constructor records the instance in MODULE$.
    static {
        new CLMatrixTransposeCopy$();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.util.WeakHashMap] */
    /* JADX WARN: Type inference failed for: r0v2, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v37 */
    public CLMatrixTransposeCopy apply(int i, CLContext cLContext) {
        int[] iArr;
        ?? map = map();
        synchronized (map) {
            if (((CLDevice) Predef$.MODULE$.refArrayOps(cLContext.getDevices()).head()).toString().contains("Apple") && ((CLDevice) Predef$.MODULE$.refArrayOps(cLContext.getDevices()).head()).toString().contains("Intel Core")) {
                iArr = new int[]{1, 1, 1};
            } else {
                long unboxToLong = BoxesRunTime.unboxToLong(new RichLong(Predef$.MODULE$.longWrapper(((CLDevice) Predef$.MODULE$.refArrayOps(cLContext.getDevices()).head()).getMaxWorkItemSizes()[0])).min(BoxesRunTime.boxToLong(32)));
                iArr = new int[]{(int) unboxToLong, RichInt$.MODULE$.min$extension(Predef$.MODULE$.intWrapper((int) (((CLDevice) Predef$.MODULE$.refArrayOps(cLContext.getDevices()).head()).getMaxWorkGroupSize() / unboxToLong)), 4), 1};
            }
            int[] iArr2 = iArr;
            CLProgram createProgram = cLContext.createProgram(permuteTransposeCopy(32, iArr2));
            Object orElseUpdate = ((MapLike) JavaConverters$.MODULE$.mapAsScalaMapConverter(map()).asScala()).getOrElseUpdate(cLContext, new CLMatrixTransposeCopy$$anonfun$apply$3(iArr2, createProgram.createKernel("transpose_copy", new Object[0]), createProgram.createKernel("transpose_copy_out", new Object[0])));
            map = map;
            return (CLMatrixTransposeCopy) orElseUpdate;
        }
    }

    /**
     * Supplies the default value for the first parameter of {@code apply}.
     *
     * @return the constant 32
     */
    public int apply$default$1() {
        final int defaultFirstArgument = 32;
        return defaultFirstArgument;
    }

    /**
     * Accessor for the per-context cache field.
     *
     * @return the map stored in the {@code map} field
     */
    private WeakHashMap<CLContext, CLMatrixTransposeCopy> map() {
        final WeakHashMap<CLContext, CLMatrixTransposeCopy> cache = map;
        return cache;
    }

    /**
     * Builds the OpenCL C source text containing the two kernels {@code transpose_copy} and
     * {@code transpose_copy_out}.
     *
     * <p>The generated text defines {@code T} as {@code float}, defines {@code BLOCK_SIZE} as
     * {@code i}, and joins the elements of {@code iArr} with {@code ", "} to form the argument
     * list of the {@code reqd_work_group_size} attribute placed on {@code transpose_copy}.
     * Everything else in the returned string is fixed text.
     *
     * <p>NOTE(review): in this decompiled listing the second string literal is split across two
     * physical lines (after {@code "int srcOff, int "}); confirm against the original Scala
     * source whether that break is an artifact before touching the literal.
     *
     * @param i value substituted for the {@code BLOCK_SIZE} macro
     * @param iArr work-group dimensions written into {@code reqd_work_group_size(...)}
     * @return the complete kernel source as a single string
     */
    public String permuteTransposeCopy(int i, int[] iArr) {
        return new StringBuilder().append((Object) "\n#define T float\n#define BLOCK_SIZE ").append(BoxesRunTime.boxToInteger(i)).append((Object) "\n\n__attribute__((reqd_work_group_size(").append((Object) Predef$.MODULE$.intArrayOps(iArr).mkString(", ")).append((Object) ")))\n__kernel void transpose_copy(__global T* _dst, int dstOff, int dstMajorStride, \n                             __global T* _src, int srcOff, int srcMajorStride, __global int* srcPtrs,\n                             int srcRows, int colOff, int srcCols) {\n  int numGroupsX = BLOCK_SIZE * get_num_groups(0);\n  int numGroupsY = BLOCK_SIZE * get_num_groups(1);\n  int firstBlockX = BLOCK_SIZE * get_group_id(0);\n  int firstBlockY = BLOCK_SIZE * get_group_id(1);\n  __local float tile[BLOCK_SIZE][BLOCK_SIZE+1];\n\n\n  int threadid = get_local_id(0);\n  int threadidy = get_local_id(1);\n\n  __global T* dst = _dst + dstOff;\n  __global T* src = _src + srcOff;\n\n  for (int yb = firstBlockY; yb < srcCols; yb += numGroupsY) {\n    for (int xb = firstBlockX; xb < srcRows; xb += numGroupsX) {\n      int ylim = min(srcCols, yb + BLOCK_SIZE);\n      int xlim = min(srcRows, xb + BLOCK_SIZE);\n      #pragma unroll\n      for (int y = threadidy + yb; y < ylim; y += get_local_size(1)) {\n       #pragma unroll\n        for(int x = threadid + xb; x < xlim; x += get_local_size(0)) {\n          tile[x-xb][y-yb] = src[srcPtrs[colOff + y]*srcMajorStride + x];\n        }\n      }\n      barrier(CLK_LOCAL_MEM_FENCE);\n      #pragma unroll\n      for (int x = threadidy + xb; x < xlim; x += get_local_size(1)) {\n       #pragma unroll\n        for(int y = yb + threadid; y < ylim; y += get_local_size(0)) {\n          dst[y + x*dstMajorStride] = tile[x-xb][y-yb];\n        }\n      }\n      barrier(CLK_LOCAL_MEM_FENCE);\n    }\n  }\n\n}\n\n__kernel void transpose_copy_out(\n      __global T* _dst, int dstOff, int dstMajorStride,\n       __global int* _dstPtrs, int dstColOffset,\n      __global T* _src, int srcOff, int 
srcMajorStride, \n      int srcRows, int srcCols) {\n  // copy each col into block[i]\n  __local T block[BLOCK_SIZE][BLOCK_SIZE+1]; // + 1 to avoid bank conflicts\n  __global int* dstPtrs = _dstPtrs + dstColOffset;\n\n  __global T* dst = _dst + dstOff;\n  __global T* src = _src + srcOff;\n\n  int srcCol = get_global_id(0);\n  int threadid = get_local_id(0);\n  int numThreads = get_local_size(0);\n  // srcCol - threadid is the same for all threads in a workgroup.\n  int firstSrcCol = get_group_id(0) * BLOCK_SIZE;\n  int nColsToDo = max(min(BLOCK_SIZE, srcCols - firstSrcCol),0);\n\n  int firstSrcRow = get_global_id(1) * BLOCK_SIZE;\n  int nRowsToDo =  min(BLOCK_SIZE, srcRows - firstSrcRow);\n\n  __local int myPtrs[BLOCK_SIZE];\n  event_t copyFirstPtr = async_work_group_copy(myPtrs, dstPtrs + firstSrcRow, nRowsToDo, 0);\n\n\n  for(int i = 0; i < nColsToDo; ++i) {\n    for(int row = threadid; row < nRowsToDo; row += numThreads) {\n      block[i][row] = src[srcMajorStride * (firstSrcCol + i) + firstSrcRow + row];\n    }\n    //copyInEvents[i] = async_work_group_copy(block[i], // block(i, ::)\n    //  src + srcMajorStride * (firstSrcCol + i) + firstSrcRow, // src(firstSrcRow --> nRowsToDo, myPtrs(i))\n    // nRowsToDo, 0); //\n    //// TODO: why is this necessary on intel? the wait_group_events below doesn't work.\n    //wait_group_events(1, copyInEvents + i);\n  }\n\n\n  wait_group_events(1, &copyFirstPtr);\n  barrier(CLK_LOCAL_MEM_FENCE);\n\n  // each block[i] now contains the slice src(firstSrcRow --> nRowsToDo, firstSrcCol + i)\n  // we want to move src(firstSrcRow, ::) to dst(::, dstPtrs(firstSrcRow))\n  // so we want thread i to write block[i][j] to dst(dstRow, firstSrcRow + j)\n\n  for(int j = 0; j < nRowsToDo && threadid < nColsToDo; j += 1) {\n    dst[myPtrs[j] * dstMajorStride + srcCol] = block[threadid][j];\n  }\n\n\n}\n\n\n                                                                  ").toString();
    }

    /**
     * Private constructor used only by the static initializer. It first records the new instance
     * in {@code MODULE$} and then creates the empty per-context cache.
     */
    private CLMatrixTransposeCopy$() {
        // Statement order is kept as in the compiled class: publish the singleton, then the cache.
        MODULE$ = this;
        map = new WeakHashMap<CLContext, CLMatrixTransposeCopy>();
    }
}
