/*
 * Decompiled with CFR 0.152.
 */
package edu.usc.irds.agepredictor.cmdline.spark.authorage;

import edu.usc.irds.agepredictor.spark.authorage.AgePredictModel;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import opennlp.tools.authorage.AgeClassifyME;
import opennlp.tools.authorage.AgeClassifyModel;
import opennlp.tools.cmdline.BasicCmdLineTool;
import opennlp.tools.cmdline.CmdLineUtil;
import opennlp.tools.cmdline.SystemInputStreamFactory;
import opennlp.tools.util.InputStreamFactory;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.ParagraphStream;
import opennlp.tools.util.PlainTextByLineStream;
import opennlp.tools.util.featuregen.FeatureGenerator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class AgePredictTool
extends BasicCmdLineTool
implements Serializable {
    public String getShortDescription() {
        return "age predictor";
    }

    public String getHelp() {
        return "Usage: bin/authorage " + this.getName() + " [MaxEntModel] RegressionModel Documents";
    }

    public void run(String[] args) {
        AgePredictModel model = null;
        AgeClassifyME classify = null;
        if (args.length == 3) {
            try {
                AgeClassifyModel classifyModel = new AgeClassifyModel(new File(args[0]));
                classify = new AgeClassifyME(classifyModel);
                model = AgePredictModel.readModel(new File(args[1]));
            }
            catch (Exception e) {
                e.printStackTrace();
                return;
            }
        } else if (args.length == 2) {
            try {
                model = AgePredictModel.readModel(new File(args[0]));
            }
            catch (Exception e) {
                e.printStackTrace();
                return;
            }
        } else {
            System.out.println(this.getHelp());
            return;
        }
        ArrayList<Row> data = new ArrayList<Row>();
        SparkSession spark = SparkSession.builder().appName("AgePredict").getOrCreate();
        try {
            String document;
            System.out.println("Please enter your text separted by newline. When done press ctrl+d to terminate system input");
            ParagraphStream documentStream = new ParagraphStream((ObjectStream)new PlainTextByLineStream((InputStreamFactory)new SystemInputStreamFactory(), SystemInputStreamFactory.encoding()));
            FeatureGenerator[] featureGenerators = model.getContext().getFeatureGenerators();
            while ((document = (String)documentStream.read()) != null) {
                String[] tokens = model.getContext().getTokenizer().tokenize(document);
                double[] prob = classify.getProbabilities(tokens);
                String category = classify.getBestCategory(prob);
                ArrayList<String> context = new ArrayList<String>();
                for (FeatureGenerator featureGenerator : featureGenerators) {
                    Collection extractedFeatures = featureGenerator.extractFeatures(tokens);
                    context.addAll(extractedFeatures);
                }
                if (category != null) {
                    for (int i = 0; i < tokens.length / 18; ++i) {
                        context.add("cat=" + category);
                    }
                }
                if (context.size() <= 0) continue;
                data.add(RowFactory.create((Object[])new Object[]{document, context.toArray()}));
            }
        }
        catch (IOException e) {
            e.printStackTrace();
            CmdLineUtil.handleStdinIoError((IOException)e);
        }
        StructType schema = new StructType(new StructField[]{new StructField("document", DataTypes.StringType, false, Metadata.empty()), new StructField("text", (DataType)new ArrayType(DataTypes.StringType, true), false, Metadata.empty())});
        Dataset df = spark.createDataFrame(data, schema);
        CountVectorizerModel cvm = new CountVectorizerModel(model.getVocabulary()).setInputCol("text").setOutputCol("feature");
        Dataset eventDF = cvm.transform(df);
        Normalizer normalizer = ((Normalizer)((Normalizer)new Normalizer().setInputCol("feature")).setOutputCol("normFeature")).setP(1.0);
        JavaRDD normEventDF = normalizer.transform(eventDF).javaRDD();
        final LassoModel linModel = model.getModel();
        normEventDF.foreach((VoidFunction)new VoidFunction<Row>(){

            public void call(Row event) {
                SparseVector sp = (SparseVector)event.getAs("normFeature");
                double prediction = linModel.predict(Vectors.sparse((int)sp.size(), (int[])sp.indices(), (double[])sp.values()));
                System.out.println((String)event.getAs("document"));
                System.out.println("Prediction: " + prediction);
            }
        });
        spark.stop();
    }
}

