package epic.sequences;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.DenseVector$mcD$sp;
import breeze.linalg.max$;
import breeze.linalg.softmax$;
import breeze.linalg.support.CanTraverseValues$OpArrayDD$;
import breeze.util.Index;
import epic.constraints.LabeledSpanConstraints;
import epic.sequences.SemiCRF;
import epic.trees.Span;
import epic.trees.Span$;
import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import scala.Array$;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IndexedSeq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.math.Ordering$Int$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: SemiCRF.scala */
/* loaded from: input_file:epic/sequences/SemiCRF$Marginal$.class */
/**
 * Java view of the Scala companion object of {@code SemiCRF.Marginal}.
 *
 * <p>Provides factories for {@link SemiCRF.Marginal} instances: a marginal computed from
 * forward/backward score tables ({@link #apply}), a marginal built from a fixed segmentation
 * ({@link #goldMarginal}), and one built from the segmentation returned by
 * {@code SemiCRF$.viterbi} ({@link #maxDerivationMarginal}).
 *
 * <p>All score tables are indexed as {@code table[position][label]} with
 * {@code position} in {@code [0, anchoring.length()]}.
 */
public class SemiCRF$Marginal$ {
    /** Singleton instance, mirroring the Scala {@code object}. */
    public static final SemiCRF$Marginal$ MODULE$ = new SemiCRF$Marginal$();

    /**
     * Kept public for compatibility with existing callers; normal code should use
     * {@link #MODULE$} instead of creating new instances.
     */
    public SemiCRF$Marginal$() {
    }

    /**
     * Builds the marginal that corresponds to the single segmentation returned by
     * {@code SemiCRF$.viterbi} for the given anchoring.
     *
     * @param anchoring the anchoring to decode
     * @return the result of {@link #goldMarginal} applied to the decoded segmentation
     */
    public <L, W> SemiCRF.Marginal<L, W> maxDerivationMarginal(SemiCRF.Anchoring<L, W> anchoring) {
        return goldMarginal(anchoring, SemiCRF$.MODULE$.viterbi(anchoring, SemiCRF$.MODULE$.viterbi$default$2()).mo1289label());
    }

    /**
     * Computes forward and backward score tables for {@code anchoring} and wraps them in a
     * {@link SemiCRF.Marginal}.
     *
     * <p>The log partition is obtained by applying breeze's {@code softmax} to the last row of
     * the forward table.
     *
     * @param anchoring the anchoring to compute marginals for
     * @return a marginal backed by the computed tables
     */
    public <L, W> SemiCRF.Marginal<L, W> apply(final SemiCRF.Anchoring<L, W> anchoring) {
        final double[][] forwardScores = forwardScores(anchoring);
        final double[][] backwardScores = backwardScores(anchoring);
        final double partition = BoxesRunTime.unboxToDouble(softmax$.MODULE$.apply(Predef$.MODULE$.refArrayOps(forwardScores).mo2521last(), softmax$.MODULE$.reduceDouble(CanTraverseValues$OpArrayDD$.MODULE$, max$.MODULE$.reduce_Double(CanTraverseValues$OpArrayDD$.MODULE$))));
        // SemiCRF.Marginal is a trait interface, so the anonymous class takes no constructor
        // arguments; the tables are captured from the enclosing (effectively final) locals.
        return new SemiCRF.Marginal<L, W>() { // from class: epic.sequences.SemiCRF$Marginal$$anon$2
            {
                // Runs the trait's initialisation code, as the Scala compiler does.
                SemiCRF.Marginal.Cclass.$init$(this);
            }

            // The following overrides forward to the trait implementation class.

            @Override // epic.sequences.SemiCRF.Marginal
            public IndexedSeq<W> words() {
                return SemiCRF.Marginal.Cclass.words(this);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public int length() {
                return SemiCRF.Marginal.Cclass.length(this);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public double spanMarginal(int i, int i2, int i3) {
                return SemiCRF.Marginal.Cclass.spanMarginal(this, i, i2, i3);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public DenseVector<Object> spanMarginal(int i, int i2) {
                return SemiCRF.Marginal.Cclass.spanMarginal(this, i, i2);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public LabeledSpanConstraints<L> computeSpanConstraints(double d) {
                return SemiCRF.Marginal.Cclass.computeSpanConstraints(this, d);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public boolean hasSupportOver(SemiCRF.Marginal<L, W> marginal) {
                return SemiCRF.Marginal.Cclass.hasSupportOver(this, marginal);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public String decode() {
                return SemiCRF.Marginal.Cclass.decode(this);
            }

            @Override // epic.sequences.SemiCRF.Marginal
            public double computeSpanConstraints$default$1() {
                return SemiCRF.Marginal.Cclass.computeSpanConstraints$default$1(this);
            }

            /** Returns the anchoring this marginal was computed from. */
            @Override // epic.sequences.SemiCRF.Marginal
            public SemiCRF.Anchoring<L, W> anchoring() {
                return anchoring;
            }

            /**
             * Reports every transition with a non-zero marginal to {@code transitionVisitor}.
             *
             * <p>Span starts are visited from {@code length() - 1} down to 0; for each start,
             * every previous label is paired with every allowed span end (longest first) and
             * every label whose maximum segment length admits the span and whose backward
             * score at the span end is not negative infinity.
             */
            @Override // epic.framework.VisitableMarginal
            public void visit(SemiCRF.TransitionVisitor<L, W> transitionVisitor) {
                int numLabels = anchoring.labelIndex().size();
                int length = length();
                for (int begin = length - 1; begin >= 0; begin--) {
                    for (int prevLabel = 0; prevLabel < numLabels; prevLabel++) {
                        int maxEnd = anchoring().constraints().maxSpanLengthStartingAt(begin) + begin;
                        for (int end = maxEnd; end > begin; end--) {
                            if (!anchoring().constraints().isAllowedSpan(begin, end)) {
                                continue;
                            }
                            for (int label = 0; label < numLabels; label++) {
                                double backward = backwardScores[end][label];
                                if (anchoring().maxSegmentLength(label) >= end - begin && backward != Double.NEGATIVE_INFINITY) {
                                    double marginal = transitionMarginal(prevLabel, label, begin, end);
                                    if (marginal != 0.0d) {
                                        transitionVisitor.visitTransition(prevLabel, label, begin, end, marginal);
                                    }
                                }
                            }
                        }
                    }
                }
            }

            /**
             * Marginal of the transition from label {@code i} to label {@code i2} over the span
             * {@code [i3, i4)}.
             *
             * @return {@code 0.0} when forward + backward is infinite; otherwise
             *     {@code exp(forward + backward + transitionScore - logPartition)}
             */
            @Override // epic.sequences.SemiCRF.Marginal
            public double transitionMarginal(int i, int i2, int i3, int i4) {
                double outside = forwardScores[i3][i] + backwardScores[i4][i2];
                if (Double.isInfinite(outside)) {
                    return 0.0d;
                }
                return Math.exp((outside + anchoring().scoreTransition(i, i2, i3, i4)) - logPartition());
            }

            /** Returns the log partition computed when this marginal was created. */
            @Override // epic.sequences.SemiCRF.Marginal, epic.framework.Marginal
            public double logPartition() {
                return partition;
            }
        };
    }

    /**
     * Builds a marginal for a fixed segmentation.
     *
     * <p>Three {@code int} arrays, each as long as the end offset of the last span in
     * {@code indexedSeq}, are allocated and then populated by the
     * {@code $anonfun$goldMarginal$2} callback while it walks the segmentation; the callback
     * also updates the running previous label, score and position holders. The semantics of
     * the individual arrays are defined by that callback and by {@code $anon$3}.
     *
     * @param anchoring the anchoring the segmentation belongs to
     * @param indexedSeq the (label, span) segmentation; must be non-empty because its last
     *     element is read
     * @return the marginal for that segmentation
     */
    public <L, W> SemiCRF.Marginal<L, W> goldMarginal(SemiCRF.Anchoring<L, W> anchoring, IndexedSeq<Tuple2<L, Span>> indexedSeq) {
        IntRef prevLabel = new IntRef(anchoring.labelIndex().apply((Index<L>) anchoring.startSymbol()));
        DoubleRef score = new DoubleRef(0.0d);
        IntRef position = new IntRef(0);
        // End offset of the final span; all three tables share this length.
        int end = Span$.MODULE$.end$extension(indexedSeq.mo2521last().mo2366_2().encoded());
        int[] iArr = (int[]) Array$.MODULE$.fill(end, new SemiCRF$Marginal$$anonfun$1(), ClassTag$.MODULE$.Int());
        int[] iArr2 = (int[]) Array$.MODULE$.fill(end, new SemiCRF$Marginal$$anonfun$2(), ClassTag$.MODULE$.Int());
        int[] iArr3 = (int[]) Array$.MODULE$.fill(end, new SemiCRF$Marginal$$anonfun$3(), ClassTag$.MODULE$.Int());
        indexedSeq.withFilter(new SemiCRF$Marginal$$anonfun$goldMarginal$1()).foreach(new SemiCRF$Marginal$$anonfun$goldMarginal$2(anchoring, indexedSeq, prevLabel, score, position, iArr, iArr2, iArr3));
        return new SemiCRF$Marginal$$anon$3(anchoring, indexedSeq, score, iArr, iArr2, iArr3, anchoring);
    }

    /**
     * Fills the forward table: {@code result[end][label]} combines, via breeze's
     * {@code softmax}, the scores of all allowed segments with that label ending at
     * {@code end}.
     *
     * <p>Row 0 holds {@code 0.0} for the start symbol; every other cell starts from the value
     * produced by {@code $anonfun$4} (presumably negative infinity — confirm against that
     * class). When the anchoring ignores the transition model, the previous row is collapsed
     * with {@code softmax} once and {@code -1} is passed as the previous label.
     *
     * @param anchoring the anchoring to score
     * @return a {@code (length + 1) x numLabels} table
     */
    private <L, W> double[][] forwardScores(SemiCRF.Anchoring<L, W> anchoring) {
        int length = anchoring.length();
        int numLabels = anchoring.labelIndex().size();
        double[][] forward = (double[][]) Array$.MODULE$.fill(length + 1, numLabels, new SemiCRF$Marginal$$anonfun$4(), ClassTag$.MODULE$.Double());
        forward[0][anchoring.labelIndex().apply((Index<L>) anchoring.startSymbol())] = 0.0d;
        // Scratch buffer reused for every cell; only the first `count` entries are meaningful.
        double[] scratch = new double[numLabels * length];
        for (int end = 1; end <= length; end++) {
            for (int label = 0; label < numLabels; label++) {
                int count = 0;
                int firstBegin = Math.max(end - anchoring.maxSegmentLength(label), 0);
                for (int begin = firstBegin; begin < end; begin++) {
                    if (!anchoring.constraints().isAllowedLabeledSpan(begin, end, label)) {
                        continue;
                    }
                    if (anchoring.ignoreTransitionModel()) {
                        double prev = softmax$.MODULE$.array(forward[begin], forward[begin].length);
                        if (prev != Double.NEGATIVE_INFINITY) {
                            double candidate = anchoring.scoreTransition(-1, label, begin, end) + prev;
                            if (candidate != Double.NEGATIVE_INFINITY) {
                                scratch[count] = candidate;
                                count++;
                            }
                        }
                    } else {
                        for (int prevLabel = 0; prevLabel < numLabels; prevLabel++) {
                            double prev = forward[begin][prevLabel];
                            if (prev != Double.NEGATIVE_INFINITY) {
                                double candidate = anchoring.scoreTransition(prevLabel, label, begin, end) + prev;
                                if (candidate != Double.NEGATIVE_INFINITY) {
                                    scratch[count] = candidate;
                                    count++;
                                }
                            }
                        }
                    }
                }
                forward[end][label] = softmax$.MODULE$.array(scratch, count);
            }
        }
        return forward;
    }

    /**
     * Fills the backward table: {@code result[begin][prevLabel]} combines, via breeze's
     * {@code softmax}, the scores of all allowed segments starting at {@code begin} that
     * follow {@code prevLabel}.
     *
     * <p>The last row is set to {@code 0.0}; every other cell starts from the value produced
     * by {@code $anonfun$5} (presumably negative infinity — confirm against that class).
     *
     * @param anchoring the anchoring to score
     * @return a {@code (length + 1) x numLabels} table
     */
    private <L, W> double[][] backwardScores(SemiCRF.Anchoring<L, W> anchoring) {
        int length = anchoring.length();
        int numLabels = anchoring.labelIndex().size();
        double[][] backward = (double[][]) Array$.MODULE$.fill(length + 1, numLabels, new SemiCRF$Marginal$$anonfun$5(), ClassTag$.MODULE$.Double());
        Arrays.fill(backward[length], 0.0d);
        // Scratch buffer sized by numLabels times the maximum of the values produced by
        // $anonfun$6 over all label ids (presumably each label's max segment length).
        double[] scratch = new double[numLabels * BoxesRunTime.unboxToInt(((TraversableOnce) RichInt$.MODULE$.until$extension0(0, numLabels).map(new SemiCRF$Marginal$$anonfun$6(anchoring), IndexedSeq$.MODULE$.canBuildFrom())).mo2519max(Ordering$Int$.MODULE$))];
        for (int begin = length - 1; begin >= 0; begin--) {
            for (int prevLabel = 0; prevLabel < numLabels; prevLabel++) {
                int count = 0;
                int maxEnd = anchoring.constraints().maxSpanLengthStartingAt(begin) + begin;
                for (int end = maxEnd; end > begin; end--) {
                    if (!anchoring.constraints().isAllowedSpan(begin, end)) {
                        continue;
                    }
                    for (int label = 0; label < numLabels; label++) {
                        double next = backward[end][label];
                        if (anchoring.maxSegmentLength(label) >= end - begin && next != Double.NEGATIVE_INFINITY) {
                            double candidate = anchoring.scoreTransition(prevLabel, label, begin, end) + next;
                            if (candidate != Double.NEGATIVE_INFINITY) {
                                scratch[count] = candidate;
                                count++;
                            }
                        }
                    }
                }
                backward[begin][prevLabel] = BoxesRunTime.unboxToDouble(softmax$.MODULE$.apply(new DenseVector$mcD$sp(scratch, 0, 1, count), softmax$.MODULE$.reduceDouble(DenseVector$.MODULE$.canIterateValues(), max$.MODULE$.reduce_Double(DenseVector$.MODULE$.canIterateValues()))));
            }
        }
        return backward;
    }
}
