/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.coreference.ae;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ctakes.core.pipeline.PipeBitInfo;
import org.apache.ctakes.core.util.ListFactory;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterAgreementFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterAttributeFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterDepHeadExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterSalienceFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterSectionFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterSemTypeDepPrefsFeatureExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterStackFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterStringFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterUMLSFeatureExtractor;
import org.apache.ctakes.coreference.ae.pairing.cluster.ClusterMentionPairer_ImplBase;
import org.apache.ctakes.coreference.ae.pairing.cluster.ClusterPairer;
import org.apache.ctakes.coreference.ae.pairing.cluster.HeadwordPairer;
import org.apache.ctakes.coreference.ae.pairing.cluster.SectionHeaderPairer;
import org.apache.ctakes.coreference.ae.pairing.cluster.SentenceDistancePairer;
import org.apache.ctakes.coreference.util.MarkableUtilities;
import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
import org.apache.ctakes.relationextractor.eval.RelationExtractorEvaluation;
import org.apache.ctakes.typesystem.type.refsem.AnatomicalSite;
import org.apache.ctakes.typesystem.type.refsem.DiseaseDisorder;
import org.apache.ctakes.typesystem.type.refsem.Event;
import org.apache.ctakes.typesystem.type.refsem.Medication;
import org.apache.ctakes.typesystem.type.refsem.Procedure;
import org.apache.ctakes.typesystem.type.refsem.SignSymptom;
import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
import org.apache.ctakes.typesystem.type.relation.CollectionTextRelationIdentifiedAnnotationRelation;
import org.apache.ctakes.typesystem.type.relation.CoreferenceRelation;
import org.apache.ctakes.typesystem.type.textsem.AnatomicalSiteMention;
import org.apache.ctakes.typesystem.type.textsem.DiseaseDisorderMention;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.typesystem.type.textsem.Markable;
import org.apache.ctakes.typesystem.type.textsem.MedicationMention;
import org.apache.ctakes.typesystem.type.textsem.ProcedureMention;
import org.apache.ctakes.typesystem.type.textsem.SignSymptomMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
import org.apache.ctakes.utils.struct.CounterMap;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.FeatureStructure;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.EmptyFSList;
import org.apache.uima.jcas.cas.FSArray;
import org.apache.uima.jcas.cas.FSList;
import org.apache.uima.jcas.cas.NonEmptyFSList;
import org.apache.uima.jcas.cas.TOP;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.CleartkProcessingException;
import org.cleartk.ml.DataWriter;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.FeatureExtractor1;
import org.cleartk.util.ViewUriUtil;

/**
 * Coreference annotator using the mention-synchronous (mention-to-cluster) paradigm.
 * <p>
 * Every {@link Markable} is visited segment by segment and compared against the candidate clusters proposed by
 * the configured {@link ClusterMentionPairer_ImplBase pairers}.
 * <ul>
 *   <li>During training, each candidate cluster/mention pair is written as a classifier instance. Its label is the
 *   gold relation category, or {@link #NO_RELATION_CATEGORY} for a randomly down-sampled negative.</li>
 *   <li>At prediction time, the mention is appended to a cluster the classifier accepts. With {@code greedyFirst},
 *   that is the first accepted candidate; otherwise it is the highest-scoring one. If no cluster accepts the
 *   mention, it starts a new single-member cluster.</li>
 * </ul>
 * At the end of {@link #process(JCas)}, clusters with fewer than two members are removed from the indexes. A
 * refsem {@link Event} (or a subtype chosen from the members' covering annotations) is then created for every
 * remaining cluster.
 */
@PipeBitInfo(name="Coreference (Clusters)", description="Coreference annotator using mention-synchronous paradigm.", dependencies={PipeBitInfo.TypeProduct.BASE_TOKEN, PipeBitInfo.TypeProduct.SENTENCE, PipeBitInfo.TypeProduct.SECTION, PipeBitInfo.TypeProduct.IDENTIFIED_ANNOTATION, PipeBitInfo.TypeProduct.MARKABLE}, products={PipeBitInfo.TypeProduct.COREFERENCE_RELATION})
public class MentionClusterCoreferenceAnnotator
extends CleartkAnnotator<String> {
    /** Outcome label for a cluster/mention pair that is not coreferent. */
    public static final String NO_RELATION_CATEGORY = "-NONE-";

    /** Name of the parameter giving the probability of keeping a negative training example. */
    public static final String PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE = "ProbabilityOfKeepingANegativeExample";
    @ConfigurationParameter(name=PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE, mandatory=false, description="probability that a negative example should be retained for training")
    protected double probabilityOfKeepingANegativeExample = 0.5;

    /** Name of the parameter that enables reuse of a data writer created by an earlier instance. */
    public static final String PARAM_USE_EXISTING_ENCODERS = "UseExistingEncoders";
    @ConfigurationParameter(name=PARAM_USE_EXISTING_ENCODERS, mandatory=false, description="Whether to use encoders in output directory during data writing; if we are making multiple calls")
    private boolean useExistingEncoders = false;

    /** Category of the relation linking a mention to the cluster it belongs to. */
    private static final String CLUSTER_MEMBER_CATEGORY = "CoreferenceClusterMember";
    /** Category of the clusters this annotator creates. */
    private static final String IDENTITY_CATEGORY = "Identity";

    /** Fixed seed, so negative down-sampling is reproducible from run to run. */
    protected Random coin = new Random(0L);
    /** If true, a mention joins the first accepted candidate cluster; otherwise the highest-scoring one. */
    boolean greedyFirst = true;
    /** Static so that a later instance can reuse the writer when {@code useExistingEncoders} is set. */
    private static DataWriter<String> classDataWriter = null;

    // NOTE: these initializers call overridable factory methods while the object is being constructed,
    // so overrides must not depend on state initialized by a subclass.
    private List<RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation>> relationExtractors = this.getFeatureExtractors();
    private List<FeatureExtractor1<Markable>> mentionExtractors = this.getMentionExtractors();
    private List<ClusterMentionPairer_ImplBase> pairExtractors = this.getPairExtractors();

    /**
     * Creates a description for running this annotator in training (data-writing) mode.
     *
     * @param dataWriterClass  data writer used to serialize training instances
     * @param outputDirectory  directory the data writer writes to
     * @param downsamplingRate probability of keeping each negative example
     * @return the engine description
     * @throws ResourceInitializationException if the description cannot be created
     */
    public static AnalysisEngineDescription createDataWriterDescription(Class<? extends DataWriter<String>> dataWriterClass, File outputDirectory, float downsamplingRate) throws ResourceInitializationException {
        return AnalysisEngineFactory.createEngineDescription(MentionClusterCoreferenceAnnotator.class,
                "isTraining", true,
                PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE, Float.valueOf(downsamplingRate),
                "dataWriterClassName", dataWriterClass,
                "outputDirectory", outputDirectory);
    }

    /**
     * Creates a description for running this annotator with a trained model.
     *
     * @param modelPath path to the classifier jar
     * @return the engine description
     * @throws ResourceInitializationException if the description cannot be created
     */
    public static AnalysisEngineDescription createAnnotatorDescription(String modelPath) throws ResourceInitializationException {
        return AnalysisEngineFactory.createEngineDescription(MentionClusterCoreferenceAnnotator.class,
                "isTraining", false,
                "classifierJarPath", modelPath);
    }

    /**
     * Builds the extractors that compute features from a (cluster, mention) pair.
     * <p>
     * The semantic-type dependency-preference extractor is optional. If it fails to load, a message is printed
     * and the remaining extractors are still returned.
     *
     * @return a new mutable list of relation feature extractors
     */
    protected List<RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation>> getFeatureExtractors() {
        List<RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation>> extractors = new ArrayList<>();
        extractors.add(new MentionClusterAgreementFeaturesExtractor());
        extractors.add(new MentionClusterStringFeaturesExtractor());
        extractors.add(new MentionClusterSectionFeaturesExtractor());
        extractors.add(new MentionClusterUMLSFeatureExtractor());
        extractors.add(new MentionClusterDepHeadExtractor());
        extractors.add(new MentionClusterStackFeaturesExtractor());
        extractors.add(new MentionClusterSalienceFeaturesExtractor());
        extractors.add(new MentionClusterAttributeFeaturesExtractor());
        try {
            extractors.add(new MentionClusterSemTypeDepPrefsFeatureExtractor());
        }
        catch (IOException e) {
            // Deliberate best-effort: continue without this extractor rather than failing initialization.
            System.err.println("Could not load MentionClusterSemTypeDepPrefsFeatureExtractor; continuing without it: " + e);
            e.printStackTrace();
        }
        return extractors;
    }

    /**
     * Builds the extractors that compute features from the mention alone.
     *
     * @return a new mutable list of mention feature extractors
     */
    protected List<FeatureExtractor1<Markable>> getMentionExtractors() {
        List<FeatureExtractor1<Markable>> extractors = new ArrayList<>();
        extractors.add(new MentionClusterAgreementFeaturesExtractor());
        extractors.add(new MentionClusterSectionFeaturesExtractor());
        extractors.add(new MentionClusterUMLSFeatureExtractor());
        extractors.add(new MentionClusterDepHeadExtractor());
        extractors.add(new MentionClusterSalienceFeaturesExtractor());
        extractors.add(new MentionClusterAttributeFeaturesExtractor());
        return extractors;
    }

    /**
     * Builds the pairers that propose candidate clusters for a mention.
     *
     * @return a new mutable list of pairers
     */
    protected List<ClusterMentionPairer_ImplBase> getPairExtractors() {
        List<ClusterMentionPairer_ImplBase> pairers = new ArrayList<>();
        int sentDist = 5;
        pairers.add(new SentenceDistancePairer(sentDist));
        pairers.add(new SectionHeaderPairer(sentDist));
        pairers.add(new ClusterPairer(Integer.MAX_VALUE));
        pairers.add(new HeadwordPairer());
        return pairers;
    }

    /**
     * Collects the candidate pairs proposed by every pairer for {@code mention}.
     * <p>
     * Duplicate pairs are removed while preserving first-proposed order.
     *
     * @param jcas    the CAS being processed
     * @param mention the mention to pair
     * @return distinct candidate pairs in pairer order
     */
    protected Iterable<CollectionTextRelationIdentifiedAnnotationPair> getCandidateRelationArgumentPairs(JCas jcas, Markable mention) {
        LinkedHashSet<CollectionTextRelationIdentifiedAnnotationPair> pairs = new LinkedHashSet<>();
        for (ClusterMentionPairer_ImplBase pairer : this.pairExtractors) {
            pairs.addAll(pairer.getPairs(jcas, mention));
        }
        return pairs;
    }

    /** Resets the per-document state of every pairer. */
    private void resetPairers(JCas jcas) {
        for (ClusterMentionPairer_ImplBase pairer : this.pairExtractors) {
            pairer.reset(jcas);
        }
    }

    /**
     * Initializes the annotator and shares or reuses the static data writer.
     * <p>
     * If {@code useExistingEncoders} is set and a writer has already been stored, that writer replaces this
     * instance's own. Otherwise, in training mode, this instance's writer is stored for later instances.
     */
    @Override
    public void initialize(UimaContext context) throws ResourceInitializationException {
        super.initialize(context);
        if (this.useExistingEncoders && classDataWriter != null) {
            this.dataWriter = classDataWriter;
        } else if (this.isTraining()) {
            classDataWriter = this.dataWriter;
        }
    }

    /**
     * Clusters all {@link Markable}s in the document.
     * <p>
     * Training writes one instance per kept candidate pair. Prediction appends mentions to clusters or creates
     * new single-member clusters. Afterwards singleton clusters are dropped and event clusters are built.
     */
    @Override
    public void process(JCas jCas) throws AnalysisEngineProcessException {
        this.resetPairers(jCas);
        Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> relationLookup =
                this.isTraining() ? indexGoldClusterMemberships(jCas) : new HashMap<>();
        for (Segment segment : JCasUtil.select(jCas, Segment.class)) {
            for (Markable mention : JCasUtil.selectCovered(jCas, Markable.class, segment)) {
                this.processMention(jCas, mention, relationLookup);
            }
        }
        removeSingletonClusters(jCas);
        createEventClusters(jCas);
    }

    /**
     * Indexes the gold cluster memberships in training mode.
     * <p>
     * For every member of every existing {@link CollectionTextRelation}, a cluster-member relation is created
     * and added to the CAS indexes. The relations are returned keyed by (cluster, mention). Duplicate
     * memberships are reported on stderr, and the later one wins.
     */
    private static Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> indexGoldClusterMemberships(JCas jCas) {
        Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> relationLookup = new HashMap<>();
        for (CollectionTextRelation cluster : JCasUtil.select(jCas, CollectionTextRelation.class)) {
            for (Markable mention : JCasUtil.select(cluster.getMembers(), Markable.class)) {
                CollectionTextRelationIdentifiedAnnotationRelation relation = new CollectionTextRelationIdentifiedAnnotationRelation(jCas);
                relation.setCluster(cluster);
                relation.setMention(mention);
                relation.setCategory(CLUSTER_MEMBER_CATEGORY);
                relation.addToIndexes();
                CollectionTextRelationIdentifiedAnnotationPair key = new CollectionTextRelationIdentifiedAnnotationPair(cluster, mention);
                CollectionTextRelationIdentifiedAnnotationRelation existing = relationLookup.get(key);
                if (existing != null) {
                    System.err.println("Error in: " + ViewUriUtil.getURI(jCas).toString());
                    System.err.println("Error! This attempted relation " + relation.getCategory() + " already has a relation " + existing.getCategory() + " at this span: " + mention.getCoveredText());
                }
                relationLookup.put(key, relation);
            }
        }
        return relationLookup;
    }

    /**
     * Decides the cluster for a single mention, or writes its training instances.
     * <p>
     * If the mention is not attached to any cluster, it becomes a new single-member cluster.
     */
    private void processMention(JCas jCas, Markable mention, Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> relationLookup) throws AnalysisEngineProcessException {
        boolean singleton = true;
        double maxScore = 0.0;
        CollectionTextRelation maxCluster = null;
        for (CollectionTextRelationIdentifiedAnnotationPair pair : this.getCandidateRelationArgumentPairs(jCas, mention)) {
            CollectionTextRelation cluster = pair.getCluster();
            List<Feature> features = this.extractFeatures(jCas, cluster, mention);
            if (this.isTraining()) {
                String category = this.getRelationCategory(relationLookup, cluster, mention);
                if (category == null) {
                    continue; // negative example dropped by down-sampling
                }
                this.dataWriter.write(new Instance<>(category, features));
                if (!category.equals(NO_RELATION_CATEGORY)) {
                    // Stop at the first gold cluster, as a greedy decoder would.
                    singleton = false;
                    break;
                }
                continue;
            }
            String predicted = this.classify(features);
            if (predicted.equals(NO_RELATION_CATEGORY)) {
                continue;
            }
            // Scores are only needed once the pair has been accepted.
            Map<String, Double> scores = this.classifier.score(features);
            Double score = scores.get(predicted);
            if (this.greedyFirst) {
                this.createRelation(jCas, cluster, mention, predicted, score);
                singleton = false;
                break;
            }
            if (score > maxScore) {
                maxScore = score;
                maxCluster = cluster;
            }
        }
        if (!this.isTraining() && !this.greedyFirst && maxCluster != null) {
            this.createRelation(jCas, maxCluster, mention, CLUSTER_MEMBER_CATEGORY, maxScore);
            // The mention now belongs to maxCluster; it must not also start its own cluster.
            singleton = false;
        }
        if (singleton) {
            createSingletonCluster(jCas, mention);
        }
    }

    /**
     * Concatenates the pair features and the mention-only features for one candidate pair.
     * <p>
     * Features with a {@code null} value are set to the string {@code "NULL"} and reported on stderr.
     */
    private List<Feature> extractFeatures(JCas jCas, CollectionTextRelation cluster, Markable mention) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<>();
        for (RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation> extractor : this.relationExtractors) {
            List<Feature> feats = extractor.extract(jCas, cluster, mention);
            if (feats != null) {
                features.addAll(feats);
            }
        }
        for (FeatureExtractor1<Markable> extractor : this.mentionExtractors) {
            features.addAll(extractor.extract(jCas, mention));
        }
        for (Feature feature : features) {
            if (feature.getValue() == null) {
                feature.setValue("NULL");
                System.err.println(String.format("Null value found in %s from %s", feature, features));
            }
        }
        return features;
    }

    /** Creates and indexes a new "Identity" cluster whose only member is {@code mention}. */
    private static void createSingletonCluster(JCas jCas, Markable mention) {
        CollectionTextRelation chain = new CollectionTextRelation(jCas);
        chain.setCategory(IDENTITY_CATEGORY);
        NonEmptyFSList list = new NonEmptyFSList(jCas);
        list.setHead(mention);
        list.setTail(new EmptyFSList(jCas));
        chain.setMembers(list);
        chain.addToIndexes();
        list.addToIndexes();
        list.getTail().addToIndexes();
    }

    /**
     * Returns the training label for a candidate pair.
     *
     * @return the gold relation category if the pair is a gold membership; otherwise
     *         {@link #NO_RELATION_CATEGORY} with probability {@code probabilityOfKeepingANegativeExample},
     *         or {@code null} when the negative example is dropped
     */
    protected String getRelationCategory(Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> relationLookup, CollectionTextRelation cluster, IdentifiedAnnotation mention) {
        CollectionTextRelationIdentifiedAnnotationRelation relation = relationLookup.get(new CollectionTextRelationIdentifiedAnnotationPair(cluster, mention));
        if (relation != null) {
            return relation.getCategory();
        }
        return this.coin.nextDouble() <= this.probabilityOfKeepingANegativeExample ? NO_RELATION_CATEGORY : null;
    }

    /** Returns the classifier's predicted category for {@code features}. */
    protected String classify(List<Feature> features) throws CleartkProcessingException {
        return this.classifier.classify(features);
    }

    /**
     * Records that {@code mention} belongs to {@code cluster}.
     * <p>
     * Indexes a cluster-member relation with the given category and confidence, and appends the mention to
     * the cluster's member list.
     */
    protected void createRelation(JCas jCas, CollectionTextRelation cluster, IdentifiedAnnotation mention, String predictedCategory, Double confidence) {
        CollectionTextRelationIdentifiedAnnotationRelation relation = new CollectionTextRelationIdentifiedAnnotationRelation(jCas);
        relation.setCluster(cluster);
        relation.setMention(mention);
        relation.setCategory(predictedCategory);
        relation.setConfidence(confidence.doubleValue());
        relation.addToIndexes();
        ListFactory.append(jCas, cluster.getMembers(), mention);
    }

    /**
     * Creates one refsem {@link Event} (or subtype) per cluster, whose mentions are the cluster's members.
     * <p>
     * The subtype is the most frequent class among each member's largest covering annotation. A cluster with
     * no covering annotations yields a plain {@link Event}.
     *
     * @throws AnalysisEngineProcessException if the most frequent class has no refsem counterpart
     */
    private static void createEventClusters(JCas jCas) throws AnalysisEngineProcessException {
        Map<Markable, List<IdentifiedAnnotation>> markable2annotations = MarkableUtilities.indexCoveringUmlsAnnotations(jCas);
        for (CollectionTextRelation cluster : JCasUtil.select(jCas, CollectionTextRelation.class)) {
            List<Markable> memberList = new ArrayList<>(JCasUtil.select(cluster.getMembers(), Markable.class));
            CounterMap<Class<? extends IdentifiedAnnotation>> headCounts = new CounterMap<>();
            for (Markable member : memberList) {
                IdentifiedAnnotation largest = largestAnnotation(markable2annotations.get(member));
                if (largest != null) {
                    headCounts.add(largest.getClass());
                }
            }
            FSArray mentions = new FSArray(jCas, memberList.size());
            for (int i = 0; i < memberList.size(); i++) {
                mentions.set(i, memberList.get(i));
            }
            Event element;
            if (headCounts.size() == 0) {
                element = new Event(jCas);
            } else {
                Class<? extends IdentifiedAnnotation> mostCommon = headCounts.entrySet().stream()
                        .max(Map.Entry.comparingByValue())
                        .get()
                        .getKey();
                element = createTypedEvent(jCas, mostCommon);
            }
            element.setMentions(mentions);
            element.addToIndexes();
        }
    }

    /**
     * Returns the annotation with the longest span, or {@code null} if {@code annotations} is null or empty.
     * <p>
     * On ties the first one wins.
     */
    private static IdentifiedAnnotation largestAnnotation(List<IdentifiedAnnotation> annotations) {
        if (annotations == null) {
            return null;
        }
        IdentifiedAnnotation largest = null;
        for (IdentifiedAnnotation covering : annotations) {
            if (largest == null || covering.getEnd() - covering.getBegin() > largest.getEnd() - largest.getBegin()) {
                largest = covering;
            }
        }
        return largest;
    }

    /** Maps a textsem mention class to a new instance of its refsem counterpart. */
    private static Event createTypedEvent(JCas jCas, Class<? extends IdentifiedAnnotation> mostCommon) throws AnalysisEngineProcessException {
        if (mostCommon.equals(DiseaseDisorderMention.class)) {
            return new DiseaseDisorder(jCas);
        }
        if (mostCommon.equals(ProcedureMention.class)) {
            return new Procedure(jCas);
        }
        if (mostCommon.equals(SignSymptomMention.class)) {
            return new SignSymptom(jCas);
        }
        if (mostCommon.equals(MedicationMention.class)) {
            return new Medication(jCas);
        }
        if (mostCommon.equals(AnatomicalSiteMention.class)) {
            return new AnatomicalSite(jCas);
        }
        String message = "This coreference chain has an unknown type: " + mostCommon.getSimpleName();
        System.err.println(message);
        throw new AnalysisEngineProcessException(new IllegalStateException(message));
    }

    /**
     * Removes every cluster with fewer than two members from the CAS indexes.
     * <p>
     * A missing or empty member list counts as a degenerate cluster and is removed as well.
     */
    private static void removeSingletonClusters(JCas jcas) {
        List<CollectionTextRelation> toRemove = new ArrayList<>();
        // Collect first: removing from the index while iterating over it is not safe.
        for (CollectionTextRelation rel : JCasUtil.select(jcas, CollectionTextRelation.class)) {
            FSList members = rel.getMembers();
            if (!(members instanceof NonEmptyFSList) || ((NonEmptyFSList) members).getTail() instanceof EmptyFSList) {
                toRemove.add(rel);
            }
        }
        for (CollectionTextRelation rel : toRemove) {
            rel.removeFromIndexes();
        }
    }

    /**
     * Returns the confidence of every indexed {@link CoreferenceRelation}, keyed by its argument pair.
     *
     * @param jCas the CAS to read
     * @return map from (arg1, arg2) annotations to relation confidence
     */
    public Map<RelationExtractorEvaluation.HashableArguments, Double> getMarkablePairScores(JCas jCas) {
        Map<RelationExtractorEvaluation.HashableArguments, Double> scoreMap = new HashMap<>();
        for (CoreferenceRelation reln : JCasUtil.select(jCas, CoreferenceRelation.class)) {
            RelationExtractorEvaluation.HashableArguments pair = new RelationExtractorEvaluation.HashableArguments(reln.getArg1().getArgument(), reln.getArg2().getArgument());
            scoreMap.put(pair, reln.getConfidence());
        }
        return scoreMap;
    }

    /**
     * Immutable (cluster, mention) key.
     * <p>
     * Equality is by reference identity of both components, which makes it suitable as a hash key for
     * CAS feature structures.
     */
    public static class CollectionTextRelationIdentifiedAnnotationPair {
        private final CollectionTextRelation cluster;
        private final IdentifiedAnnotation mention;

        public CollectionTextRelationIdentifiedAnnotationPair(CollectionTextRelation cluster, IdentifiedAnnotation mention) {
            this.cluster = cluster;
            this.mention = mention;
        }

        public final CollectionTextRelation getCluster() {
            return this.cluster;
        }

        public final IdentifiedAnnotation getMention() {
            return this.mention;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            // Guard the cast so null or foreign types return false instead of throwing.
            if (!(obj instanceof CollectionTextRelationIdentifiedAnnotationPair)) {
                return false;
            }
            CollectionTextRelationIdentifiedAnnotationPair other = (CollectionTextRelationIdentifiedAnnotationPair) obj;
            return this.cluster == other.cluster && this.mention == other.mention;
        }

        @Override
        public int hashCode() {
            int clusterHash = this.cluster == null ? 0 : this.cluster.hashCode();
            int mentionHash = this.mention == null ? 0 : this.mention.hashCode();
            return 31 * clusterHash + mentionHash;
        }
    }
}

