package puck.parser.gen;

import com.nativelibs4java.opencl.CLContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import puck.package$;
import puck.parser.CLBinaryRuleUpdater;
import puck.parser.CLUnaryRuleUpdater;
import puck.parser.RuleSemiring;
import puck.parser.RuleStructure;
import puck.parser.SymId;
import puck.parser.ViterbiRuleSemiring$;

/* loaded from: input_file:puck/parser/gen/SimpleGenRuleMultiply.class */
public class SimpleGenRuleMultiply<C, L> extends JavaFriendlyGenRuleMultiply<C, L> {
    public static final int WARP_SIZE = 32;
    public static final int NUM_WARPS = 48;
    private final GrammarClusterer<C, L> clusterer;
    public RuleStructure<C, L> structure;
    private boolean writeDirectToChart;
    private RuleSemiring semiring;
    public static boolean GRAMMAR_IS_GENERATIVE = true;
    public static boolean NVIDIA_IS_STILL_STUPID = true;
    private static final String WRITE_PARENT_ATOMIC = "     typedef union { int old; float oldf; } intbox;\n     \n#ifndef NVIDIA\n     inline void write_parent_gen_atomic(volatile __global float* loc, float value) {\n        atomic_min((volatile __global int*)loc, *(int*)&value);\n      }\n#else \n     inline void write_parent_gen_atomic(volatile __global float* loc, float value) {\n        volatile __global int* d_ptr = (volatile __global int*)loc;\n        int z = *(int*)&value;\n        asm volatile(\"atom.global.min.s32 %0, [%1], %2;\" : \"=r\"(z), \"+l\"(d_ptr): \"r\"(z));\n      }\n     \n #endif \n     inline void write_parent_atomic(volatile __global float* loc, const float value) {\n       intbox old;\n       old.oldf = *loc;\n       float z = semiring_add(old.oldf, value);\n     \n       while((old.old = atomic_cmpxchg((volatile __global int*)loc, old.old, *(int*)&z)) !=  *(int*)&z) z = semiring_add(old.oldf, value);\n     }\n\n\n";

    public SimpleGenRuleMultiply(RuleStructure<C, L> ruleStructure, boolean z, RuleSemiring ruleSemiring, GrammarClusterer<C, L> grammarClusterer) {
        super(ruleStructure, z);
        this.structure = ruleStructure;
        this.writeDirectToChart = z;
        this.semiring = ruleSemiring;
        this.clusterer = grammarClusterer;
    }

    public List<IndexedUnaryRule<C, L>>[] segmentUnaries(List<IndexedUnaryRule<C, L>> list) {
        return this.clusterer.segmentUnaries(list);
    }

    public List<IndexedBinaryRule<C, L>>[][] segmentBinaries(List<IndexedBinaryRule<C, L>> list) {
        return this.clusterer.segmentBinaries(list);
    }

    @Override // puck.parser.gen.JavaFriendlyGenRuleMultiply
    public CLBinaryRuleUpdater javaBinaryRuleApplication(List<IndexedBinaryRule<C, L>> list, String str, CLContext cLContext, LoopType loopType) {
        ArrayList arrayList = new ArrayList();
        List<IndexedBinaryRule<C, L>>[][] segmentBinaries = segmentBinaries(list);
        boolean supportsExtendedAtomics = supportsExtendedAtomics(cLContext);
        for (int i = 0; i < segmentBinaries.length; i++) {
            arrayList.add(binaryKernelText(str + i, segmentBinaries[i], supportsExtendedAtomics));
        }
        return new CLBinaryRuleUpdater(compileKernels(cLContext, flatten(segmentBinaries), arrayList), loopType.queue(this.structure.numCoarseSyms(), cLContext), this.writeDirectToChart);
    }

    private String binaryKernelText(String str, List<IndexedBinaryRule<C, L>>[] listArr, boolean z) {
        StringBuilder sb = new StringBuilder();
        HashSet hashSet = new HashSet();
        HashSet hashSet2 = new HashSet();
        for (List<IndexedBinaryRule<C, L>> list : listArr) {
            for (SymId<C, L> symId : getParents(list)) {
                if (hashSet.contains(Integer.valueOf(symId.gpu()))) {
                    hashSet2.add(Integer.valueOf(symId.gpu()));
                } else {
                    hashSet.add(Integer.valueOf(symId.gpu()));
                }
            }
        }
        if (!hashSet2.isEmpty() && z) {
            sb.append("#pragma OPENCL EXTENSION cl_khr_global_int32_extended_atomics : enable\n");
        }
        appendAddition(sb);
        sb.append(WRITE_PARENT_ATOMIC);
        sb.append(CLMaskKernels.maskHeader(this.structure.numCoarseSyms()));
        sb.append("\n\n");
        for (List<IndexedBinaryRule<C, L>> list2 : listArr) {
            Collections.sort(list2, new Comparator<IndexedBinaryRule<C, L>>() { // from class: puck.parser.gen.SimpleGenRuleMultiply.1
                @Override // java.util.Comparator
                public int compare(IndexedBinaryRule<C, L> indexedBinaryRule, IndexedBinaryRule<C, L> indexedBinaryRule2) {
                    int compare = Integer.compare(indexedBinaryRule.rule().mo1575parent().gpu(), indexedBinaryRule2.rule().mo1575parent().gpu());
                    if (compare != 0) {
                        return compare;
                    }
                    int compare2 = Integer.compare(indexedBinaryRule.rule().mo1574left().gpu(), indexedBinaryRule2.rule().mo1574left().gpu());
                    return compare2 != 0 ? compare2 : Integer.compare(indexedBinaryRule.rule().mo1573right().gpu(), indexedBinaryRule2.rule().mo1573right().gpu());
                }
            });
        }
        for (int i = 0; i < listArr.length; i++) {
            sb.append("static void subpart" + i + "(const mask_t mask, __global volatile float* parents, __global int* parentIndex, int row, __global float* left, __global float* right, float scale, int numRows) {\n");
            HashMap hashMap = new HashMap();
            HashMap hashMap2 = new HashMap();
            HashMap hashMap3 = new HashMap();
            if (this.writeDirectToChart) {
                sb.append("int pi = parentIndex[row];");
            }
            HashMap hashMap4 = new HashMap();
            Iterator<IndexedBinaryRule<C, L>> it = listArr[i].iterator();
            while (it.hasNext()) {
                int gpu = it.next().rule().mo1575parent().gpu();
                Integer num = (Integer) hashMap4.get(Integer.valueOf(gpu));
                if (num == null) {
                    num = 0;
                }
                hashMap4.put(Integer.valueOf(gpu), Integer.valueOf(num.intValue() + 1));
            }
            int roundUpToMultipleOf = package$.MODULE$.roundUpToMultipleOf(Math.max(this.structure.numNonTerms(), this.structure.numTerms()), 32);
            for (IndexedBinaryRule<C, L> indexedBinaryRule : listArr[i]) {
                int gpu2 = indexedBinaryRule.rule().mo1575parent().gpu();
                String str2 = (String) hashMap.get(Integer.valueOf(gpu2));
                if (str2 == null) {
                    str2 = "parent_" + gpu2;
                    sb.append(String.format("float parent_%d = %s;\n", Integer.valueOf(gpu2), floatToString(this.semiring.zero())));
                    hashMap.put(Integer.valueOf(gpu2), str2);
                }
                int gpu3 = indexedBinaryRule.rule().mo1574left().gpu();
                String str3 = (String) hashMap2.get(Integer.valueOf(gpu3));
                if (str3 == null) {
                    str3 = "left_" + gpu3;
                    sb.append(String.format("float left_%d = left[%d * numRows + row];\n", Integer.valueOf(gpu3), Integer.valueOf(gpu3)));
                    hashMap2.put(Integer.valueOf(gpu3), str3);
                }
                int gpu4 = indexedBinaryRule.rule().mo1573right().gpu();
                String str4 = (String) hashMap3.get(Integer.valueOf(gpu4));
                if (str4 == null) {
                    str4 = "right_" + gpu4;
                    sb.append(String.format("float right_%d = right[%d * numRows + row];\n", Integer.valueOf(gpu4), Integer.valueOf(gpu4)));
                    hashMap3.put(Integer.valueOf(gpu4), str4);
                }
                sb.append(String.format("%s = semiring_mad(%s, %s, %ff);\n", str2, str2, this.semiring.times(str3, str4), Float.valueOf(this.structure.scores()[indexedBinaryRule.ruleId()])));
                hashMap4.put(Integer.valueOf(gpu2), Integer.valueOf(((Integer) hashMap4.get(Integer.valueOf(gpu2))).intValue() - 1));
                if (((Integer) hashMap4.get(Integer.valueOf(gpu2))).intValue() == 0) {
                    if (this.writeDirectToChart) {
                        if (this.semiring.needsScaling()) {
                            sb.append(str2 + " *= scale;");
                        }
                        sb.append(genWriteSymbol(String.format("parents[pi * " + roundUpToMultipleOf + " + %d]", Integer.valueOf(gpu2)), str2, false, z));
                    } else {
                        sb.append(genWriteSymbol(String.format("parents[%d * numRows + row]", Integer.valueOf(gpu2)), str2, !hashSet2.contains(Integer.valueOf(gpu2)), z));
                    }
                }
            }
            sb.append("}\n\n");
        }
        sb.append(String.format("#define NUM_SUB_PARTITIONS %d\n__kernel void %s(__global volatile float* parents,__global const float* parentScale,                  __global int* _parentIndex, int parentOff,                  __global float* left,                  __global const float* leftScale,                  __global int* _leftIndex, int leftOff,                   __global float* right,                  __global const float* rightScale,                  __global int* _rightIndex, int rightOff,                  __global const mask_t* masks, int numRows, int cellsToDo) {\n    int numWorkers = get_global_size(0);\n    int numPartitionGroups = get_num_groups(1);\n    __global int* parentIndex = _parentIndex + parentOff;\n    __global int* leftIndex = _leftIndex + leftOff;\n    __global int* rightIndex = _rightIndex + rightOff;\n    for(int grammarSubPartition = get_group_id(1); grammarSubPartition < NUM_SUB_PARTITIONS; grammarSubPartition += numPartitionGroups) {\n    for (int row = get_global_id(0); row < cellsToDo; row += numWorkers) {\n      const mask_t mask = masks[parentIndex[row]];\n", Integer.valueOf(listArr.length), str));
        sb.append("\n\n");
        if (this.semiring.needsScaling()) {
            sb.append("float scale = native_exp(-parentScale[parentIndex[row]] + rightScale[rightIndex[row]] + leftScale[leftIndex[row]] );");
        } else {
            sb.append("float scale = 1.0f;");
        }
        sb.append("switch (grammarSubPartition) {\n");
        for (int i2 = 0; i2 < listArr.length; i2++) {
            sb.append("case " + i2 + ": subpart" + i2 + "(mask, parents, parentIndex, row, left, right, scale, numRows); continue;\n");
        }
        sb.append("default: continue;\n");
        sb.append("}\n");
        sb.append("}\n");
        sb.append("}\n");
        sb.append("}\n");
        return sb.toString();
    }

    protected String floatToString(float f) {
        return f == Float.NEGATIVE_INFINITY ? "-INFINITY" : f + "f";
    }

    private void appendAddition(StringBuilder sb) {
        sb.append(this.semiring.includes());
    }

    private boolean semiringIsViterbi() {
        return this.semiring instanceof ViterbiRuleSemiring$;
    }

    @Override // puck.parser.gen.JavaFriendlyGenRuleMultiply
    public CLUnaryRuleUpdater javaUnaryRuleApplication(List<IndexedUnaryRule<C, L>> list, String str, CLContext cLContext) {
        ArrayList arrayList = new ArrayList();
        List<IndexedUnaryRule<C, L>>[] segmentUnaries = segmentUnaries(list);
        for (int i = 0; i < segmentUnaries.length; i++) {
            arrayList.add(unaryKernelText(str + i, segmentUnaries[i]));
        }
        return new CLUnaryRuleUpdater(compileKernels(cLContext, Arrays.asList(segmentUnaries), arrayList));
    }

    private String unaryKernelText(String str, List<IndexedUnaryRule<C, L>> list) {
        StringBuilder sb = new StringBuilder();
        appendAddition(sb);
        sb.append("\n\n\n");
        sb.append(String.format(" __kernel void %s(__global volatile float* parents,__global const float* parentScale,                  __global int* _parentIndex,                  int parentOff,  __global float* child, __global const float* childScale,                  __global int* _childIndex,                  int childOff, int numRows, int cellsToDo) {\n    int numWorkers = get_global_size(0);\n    int grammarSubPartition = get_group_id(1);\n    for (int row = get_global_id(0); row < cellsToDo; row += numWorkers) {\n", str));
        sb.append("\n\n");
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        if (this.semiring.needsScaling()) {
            sb.append("__global int* childIndex = _childIndex + childOff;");
            sb.append("__global int* parentIndex = _parentIndex + parentOff;");
            sb.append("float scale = native_exp(-parentScale[parentIndex[row]] + childScale[childIndex[row]]);");
        } else {
            sb.append("float scale = 1.0f;");
        }
        for (IndexedUnaryRule<C, L> indexedUnaryRule : list) {
            int gpu = indexedUnaryRule.rule().mo1575parent().gpu();
            String str2 = (String) hashMap.get(Integer.valueOf(gpu));
            if (str2 == null) {
                str2 = "parent_" + gpu;
                sb.append(String.format("float parent_%d = %s;\n", Integer.valueOf(gpu), floatToString(this.semiring.zero())));
                hashMap.put(Integer.valueOf(gpu), str2);
            }
            int gpu2 = indexedUnaryRule.rule().mo1619child().gpu();
            String str3 = (String) hashMap2.get(Integer.valueOf(gpu2));
            if (str3 == null) {
                str3 = "child_" + gpu2;
                sb.append(String.format("float child_%d = child[%d * numRows + row];\n", Integer.valueOf(gpu2), Integer.valueOf(gpu2)));
                hashMap2.put(Integer.valueOf(gpu2), str3);
            }
            sb.append(String.format("%s = semiring_mad(%s, %s, %ff);\n", str2, str2, str3, Float.valueOf(this.structure.scores()[indexedUnaryRule.ruleId()])));
        }
        sb.append("// write out\n");
        for (Map.Entry entry : hashMap.entrySet()) {
            if (this.semiring.needsScaling()) {
                sb.append(String.format("parents[%d * numRows + row] += %s * scale;\n", entry.getKey(), entry.getValue()));
            } else {
                sb.append(String.format("parents[%d * numRows + row] = %s;\n", entry.getKey(), entry.getValue()));
            }
        }
        sb.append("}\n");
        sb.append("}\n");
        return sb.toString();
    }

    public String genWriteSymbol(String str, String str2, boolean z, boolean z2) {
        return z ? String.format("%s = semiring_add(%s, %s);\n", str, str, str2) : (!semiringIsViterbi() || !GRAMMAR_IS_GENERATIVE || !z2) ? String.format("write_parent_atomic(&%s, %s);\n", str, str2) : String.format("write_parent_gen_atomic(&%s, %s);\n", str, str2);
    }
}
