package puck.parser.gen;

import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLKernel;
import java.util.Set;
import java.util.zip.ZipFile;
import puck.package$;
import puck.parser.LogSumRuleSemiring$;
import puck.parser.RuleSemiring;
import puck.parser.RuleStructure;
import puck.parser.SymId;
import puck.util.ZipUtil$;
import scala.Array$;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.JavaConverters$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Iterable;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Set$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: CLMBRKernels.scala */
/* loaded from: input_file:puck/parser/gen/CLMBRKernels$.class */
/**
 * Companion object of {@code CLMBRKernels}, as emitted by the Scala compiler and decompiled to Java.
 *
 * <p>Provides factories that either build the {@code computeMBR} OpenCL kernel from a
 * {@link RuleStructure} ({@link #make}) or load a previously serialized one from a zip archive
 * ({@link #read}). It also provides helpers that generate OpenCL C source text.
 *
 * <p>NOTE(review): this is decompiler output, not compilable Java source. For example,
 * {@code MODULE$} is declared {@code static final} and initialized to {@code null}, yet it is
 * reassigned in the constructor. Edit the original {@code CLMBRKernels.scala} instead.
 */
public final class CLMBRKernels$ implements Serializable {
    /** The singleton instance; assigned by the private constructor during static initialization. */
    public static final CLMBRKernels$ MODULE$ = null;

    // Static initializer: constructing the object is what publishes it into MODULE$.
    static {
        new CLMBRKernels$();
    }

    /**
     * Loads a {@code CLMBRKernels} previously stored in a zip archive under the prefix {@code str}.
     *
     * <p>Reads two entries:
     * <ul>
     *   <li>{@code str + "/MBRInts"}: deserialized by {@code ZipUtil.deserializeEntry} and cast to
     *       {@code int[]}. Element {@code [0]} becomes the first constructor argument, i.e. the
     *       value later reported as {@code maskSize()}.</li>
     *   <li>{@code str + "/computeMBRKernel"}: loaded with {@code ZipUtil.readKernel} against
     *       {@code cLContext}.</li>
     * </ul>
     *
     * <p>NOTE(review): the {@code InputStream} returned by {@code zipFile.getInputStream} is handed
     * to {@code ZipUtil.deserializeEntry}; whether it is closed depends on that helper. Confirm it.
     * If the {@code "/MBRInts"} entry is missing, {@code getEntry} returns {@code null}, and
     * {@code getInputStream(null)} is expected to throw rather than yield a clear error.
     *
     * @param str       name prefix of the entries inside the archive
     * @param zipFile   archive to read from
     * @param cLContext OpenCL context the kernel is loaded into
     * @return the reconstructed kernels wrapper
     */
    public CLMBRKernels read(String str, ZipFile zipFile, CLContext cLContext) {
        return new CLMBRKernels(((int[]) ZipUtil$.MODULE$.deserializeEntry(zipFile.getInputStream(zipFile.getEntry(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "/MBRInts"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str}))))))[0], ZipUtil$.MODULE$.readKernel(zipFile, new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", "/computeMBRKernel"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str})), cLContext));
    }

    /**
     * Generates, compiles and wraps the {@code computeMBR} kernel for the given grammar.
     *
     * <p>The first constructor argument is
     * {@code roundUpToMultipleOf(numCoarseSyms, 32) / 32}. Presumably this is the number of 32-bit
     * words needed for one bit per coarse symbol; confirm against {@code puck.package}. The program
     * text comes from {@link #programText}, called with
     * {@code max(numNonTerms, numTerms)} as the symbol-table width.
     *
     * @param ruleStructure grammar whose symbol projections are baked into the kernel
     * @param cLContext     OpenCL context used to create and compile the program
     * @param ruleSemiring  selects log-space vs. real-space marginal accumulation in the kernel
     * @return a wrapper around the compiled {@code computeMBR} kernel
     */
    public <C, L> CLMBRKernels make(RuleStructure<C, L> ruleStructure, CLContext cLContext, RuleSemiring ruleSemiring) {
        return new CLMBRKernels(package$.MODULE$.roundUpToMultipleOf(ruleStructure.numCoarseSyms(), 32) / 32, cLContext.createProgram(programText(RichInt$.MODULE$.max$extension(Predef$.MODULE$.intWrapper(ruleStructure.numNonTerms()), ruleStructure.numTerms()), ruleStructure, ruleSemiring)).createKernel("computeMBR", new Object[0]));
    }

    /**
     * Java-collection overload: converts {@code set} to an immutable Scala {@code Set} and
     * delegates to the Scala-set overload below.
     *
     * @param ruleStructure grammar used to map symbols
     * @param str           name of the mask expression referenced in the generated code
     * @param set           symbols whose mask bits are tested
     * @return an OpenCL C boolean expression, as produced by the Scala-set overload
     */
    public <C, L> String genCheckIfMaskIsEmpty(RuleStructure<C, L> ruleStructure, String str, Set<SymId<C, L>> set) {
        return genCheckIfMaskIsEmpty(ruleStructure, str, ((TraversableOnce) JavaConverters$.MODULE$.asScalaSetConverter(set).asScala()).toSet());
    }

    /**
     * Generates OpenCL C source for a boolean test over the mask named {@code str}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Maps every symbol through a synthetic closure parameterized by {@code ruleStructure}.</li>
     *   <li>Groups the results with a second closure, then filters the groups with a third.</li>
     *   <li>Renders each remaining group with a fourth closure that captures {@code str}.</li>
     *   <li>Joins the pieces with {@code mkString("(!((", ") | (", ")) )")}, giving
     *       {@code "(!((e1) | (e2) | ...)) )"}.</li>
     * </ol>
     * The closure bodies live in the {@code CLMBRKernels$$anonfun$N} classes, which are not visible
     * here. Presumably each group is one mask word and each rendered piece ANDs that word with a
     * bit pattern; verify this before relying on it.
     *
     * <p>NOTE(review): if {@code set} is empty, or every group is filtered out, {@code mkString}
     * yields just its prefix plus suffix, {@code "(!(()) )"}. That is not a valid C expression, so
     * the generated kernel would fail to compile.
     *
     * @param ruleStructure grammar used to map symbols
     * @param str           name of the mask expression referenced in the generated code
     * @param set           symbols whose mask bits are tested
     * @return an OpenCL C boolean expression string
     */
    public <C, L> String genCheckIfMaskIsEmpty(RuleStructure<C, L> ruleStructure, String str, scala.collection.immutable.Set<SymId<C, L>> set) {
        return ((Iterable) ((TraversableLike) set.map(new CLMBRKernels$$anonfun$2(ruleStructure), Set$.MODULE$.canBuildFrom())).groupBy((Function1) new CLMBRKernels$$anonfun$1()).withFilter(new CLMBRKernels$$anonfun$3()).map(new CLMBRKernels$$anonfun$4(str), Iterable$.MODULE$.canBuildFrom())).mkString("(!((", ") | (", ")) )");
    }

    /**
     * Builds the OpenCL C source of the {@code computeMBR} kernel.
     *
     * <p>Pieces of the emitted program:
     * <ul>
     *   <li>A {@code decode_t} struct with a {@code score} and a {@code symbol}.</li>
     *   <li>{@code #define NUM_SYMS i}.</li>
     *   <li>Two {@code __constant int} tables, {@code terminalProjections} and
     *       {@code nonterminalProjections}. They are taken from
     *       {@code ruleStructure.projectedTerminalMap()} and {@code projectedNonterminalMap()} and
     *       zero-padded to length {@code i}.</li>
     * </ul>
     *
     * <p>Each work-item handles one sentence, {@code get_global_id(0)}:
     * <ul>
     *   <li>It walks cells {@code indices[s] .. indices[s+1]-1} and takes the root score from
     *       {@code inside} at the last cell.</li>
     *   <li>It uses the terminal projection table for the first {@code lengths[s]} cells and the
     *       nonterminal table for the rest.</li>
     *   <li>It sums per-coarse-symbol marginals over symbols {@code 0 .. NUM_SYMS-1}. If
     *       {@code ruleSemiring} is the {@code LogSumRuleSemiring} singleton (compared by reference),
     *       each term is {@code exp(in - root_score + out)}; otherwise it is
     *       {@code in / root_score * out}.</li>
     *   <li>It writes the arg-max coarse symbol and its marginal to {@code decodeOut[cell]}. The
     *       initial value is score {@code -INFINITY}, symbol {@code -1}.</li>
     * </ul>
     *
     * <p>NOTE(review), observations about the generated kernel:
     * <ul>
     *   <li>Only {@code get_global_id(0)} is read. If the launch uses {@code global_size(1) > 1},
     *       as the embedded comment suggests, every extra work-item redundantly recomputes and
     *       rewrites the same cells.</li>
     *   <li>The symbol loop is bounded by the compile-time {@code NUM_SYMS} but strides rows by
     *       the runtime {@code numSyms}. The caller must pass {@code numSyms} consistent with
     *       {@code i}.</li>
     *   <li>{@code totalScore} and {@code isSingleWordTop} feed only a commented-out line.</li>
     *   <li>The {@code numIndices} parameter is never read.</li>
     * </ul>
     *
     * @param i             symbol-table width ({@code NUM_SYMS}); projection tables are padded to it
     * @param ruleStructure grammar supplying projection tables and {@code numCoarseSyms()}
     * @param ruleSemiring  selects the marginal formula as described above
     * @return the complete OpenCL C program text
     */
    public <L, C> String programText(int i, RuleStructure<C, L> ruleStructure, RuleSemiring ruleSemiring) {
        return new StringBuilder().append((Object) "\n    typedef struct {float score; int symbol; int ignoreMe[2];} decode_t;\n      #define NUM_SYMS ").append(BoxesRunTime.boxToInteger(i)).append((Object) "\n\n                                        ").append((Object) Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.intArrayOps(ruleStructure.projectedTerminalMap()).padTo(i, BoxesRunTime.boxToInteger(0), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).mkString("__constant int terminalProjections[] = {", ", ", "};")).append((Object) "\n      ").append((Object) Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.intArrayOps(ruleStructure.projectedNonterminalMap()).padTo(i, BoxesRunTime.boxToInteger(0), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).mkString("__constant int nonterminalProjections[] = {", ", ", "};")).append((Object) "\n\n// each global_id(0) corresponds to a single sentence.\n// we have some number of workers for each sentence, global_size(1)\n// indices(i) is the first cell in the i'th sentence\n// indices(i+1)-1 is the last cell in the i'th sentence\n// the last cell has the root score.\n//\n/** TODO this isn't optimized at all */\n__kernel void computeMBR(__global decode_t* decodeOut,\n                           __global const float* inside,\n                           __global const float* outside,\n                           __global const int* indices,\n                           __global const int* lengths,\n                           const int numIndices,\n                           int numSyms,\n                           int root) {\n  const int sentence = get_global_id(0);\n  const int firstCell = indices[sentence];\n  const int lastCell = indices[sentence + 1];\n  int length = lengths[sentence];\n  const float root_score = inside[(lastCell-1) * numSyms + root];\n\n   int firstTop = firstCell + (lastCell-firstCell)/2;\n\n  for(int cell = firstCell; cell < lastCell; cell++) {\n    __constant const int* projections = 
(cell-firstCell >= length) ? nonterminalProjections : terminalProjections;\n\n    int isSingleWordTop = (cell - firstTop <= length);\n\n    __global const float* in = inside + (cell * numSyms);\n    __global const float* out = outside + (cell * numSyms);\n\n    float coarseMargs[").append(BoxesRunTime.boxToInteger(ruleStructure.numCoarseSyms())).append((Object) "];\n    for(int coarseSym = 0; coarseSym < ").append(BoxesRunTime.boxToInteger(ruleStructure.numCoarseSyms())).append((Object) "; ++coarseSym) {\n      coarseMargs[coarseSym] = 0.0f;\n    }\n    for(int sym = 0; sym < NUM_SYMS; ++sym) {\n ").append((Object) (ruleSemiring == LogSumRuleSemiring$.MODULE$ ? "coarseMargs[projections[sym]] += exp(in[sym] - root_score + out[sym]);\n" : "coarseMargs[projections[sym]] += in[sym]/root_score * out[sym];\n")).append((Object) "\n    }\n\n    decode_t myDecode;\n    myDecode.score = -INFINITY;\n    myDecode.symbol = -1;\n    float totalScore = 0.0f;\n    for(int coarseSym = 0; coarseSym < ").append(BoxesRunTime.boxToInteger(ruleStructure.numCoarseSyms())).append((Object) "; ++coarseSym) {\n      totalScore += coarseMargs[coarseSym];\n      if (coarseMargs[coarseSym] > myDecode.score) {\n      \tmyDecode.score = coarseMargs[coarseSym];\n      \tmyDecode.symbol = coarseSym;\n      }\n    }\n\n    //if(isSingleWordTop) myDecode.score = totalScore;\n\n    decodeOut[cell] = myDecode;\n  }\n\n}\n                                                                   ").toString();
    }

    /**
     * Case-class style factory: equivalent to {@code new CLMBRKernels(i, cLKernel)}.
     *
     * @param i        value later reported as {@code maskSize()}
     * @param cLKernel the compiled MBR kernel
     * @return a new wrapper
     */
    public CLMBRKernels apply(int i, CLKernel cLKernel) {
        return new CLMBRKernels(i, cLKernel);
    }

    /**
     * Case-class style extractor.
     *
     * @param cLMBRKernels instance to destructure; may be {@code null}
     * @return {@code None} for {@code null}; otherwise {@code Some((maskSize, getMasksKernel))},
     *     with the mask size boxed as an {@code Integer}
     */
    public Option<Tuple2<Object, CLKernel>> unapply(CLMBRKernels cLMBRKernels) {
        return cLMBRKernels == null ? None$.MODULE$ : new Some(new Tuple2(BoxesRunTime.boxToInteger(cLMBRKernels.maskSize()), cLMBRKernels.getMasksKernel()));
    }

    /** Java-serialization hook: replaces any deserialized copy with the singleton {@code MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Sole constructor; publishes this instance as {@code MODULE$} (invoked from the static initializer). */
    private CLMBRKernels$() {
        MODULE$ = this;
    }
}
