Skip to content
Snippets Groups Projects
Commit 400d6fac authored by Shyam Upadhyay's avatar Shyam Upadhyay
Browse files

compiles

parent f322db16
No related branches found
No related tags found
1 merge request!1Shyam
......@@ -35,8 +35,6 @@ import edu.illinois.cs.cogcomp.core.utilities.commands.InteractiveShell;
import edu.illinois.cs.cogcomp.infer.ilp.ILPSolverFactory;
import edu.illinois.cs.cogcomp.nlp.corpusreaders.NombankReader;
import edu.illinois.cs.cogcomp.nlp.corpusreaders.PropbankReader;
import edu.illinois.cs.cogcomp.sl.core.StructuredProblem;
import edu.illinois.cs.cogcomp.sl.inference.AbstractInferenceSolver;
import edu.illinois.cs.cogcomp.sl.util.WeightVector;
import edu.illinois.cs.cogcomp.srl.caches.FeatureVectorCacheFile;
import edu.illinois.cs.cogcomp.srl.caches.SentenceDBHandler;
......@@ -372,7 +370,7 @@ public class Main {
// for (int i = 0; i < inference.length; i++)
// inference[i] = new SRLMulticlassInference(manager, model);
double c;
double c=0.01;
FeatureVectorCacheFile cache;
if (model == Models.Classifier) {
......
......@@ -4,6 +4,8 @@ import edu.illinois.cs.cogcomp.annotation.AnnotatorException;
import edu.illinois.cs.cogcomp.core.datastructures.ViewNames;
import edu.illinois.cs.cogcomp.core.datastructures.textannotation.*;
import edu.illinois.cs.cogcomp.edison.utilities.WordNetManager;
import edu.illinois.cs.cogcomp.infer.ilp.ILPSolverFactory;
import edu.illinois.cs.cogcomp.infer.ilp.ILPSolverFactory.SolverType;
import edu.illinois.cs.cogcomp.srl.core.Models;
import edu.illinois.cs.cogcomp.srl.core.SRLManager;
import edu.illinois.cs.cogcomp.srl.core.SRLType;
......@@ -11,6 +13,7 @@ import edu.illinois.cs.cogcomp.srl.experiment.TextPreProcessor;
import edu.illinois.cs.cogcomp.srl.inference.ISRLInference;
import edu.illinois.cs.cogcomp.srl.inference.SRLILPInference;
import edu.illinois.cs.cogcomp.srl.inference.SRLMulticlassInference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -143,10 +146,10 @@ public class SemanticRoleLabeler implements Annotator {
if (predicates.isEmpty())
return null;
// ISRLInference inference = new SRLILPInference(manager, ta,
// predicates, true, properties.getMaxInferenceRounds());
ISRLInference SRLMulticlassInference = new SRLMulticlassInference(manager,);
ILPSolverFactory s = new ILPSolverFactory(SolverType.Gurobi);
ISRLInference inference = new SRLILPInference(s, manager, predicates);
// ISRLInference SRLMulticlassInference = new
// SRLMulticlassInference(manager,);
return inference.getOutputView();
}
......
......@@ -4,8 +4,6 @@ import edu.illinois.cs.cogcomp.core.datastructures.Pair;
import edu.illinois.cs.cogcomp.sl.core.IInstance;
import edu.illinois.cs.cogcomp.sl.core.IStructure;
import edu.illinois.cs.cogcomp.sl.core.SLProblem;
import edu.illinois.cs.cogcomp.sl.core.StructuredProblem;
import edu.illinois.cs.cogcomp.sl.inference.AbstractInferenceSolver;
import edu.illinois.cs.cogcomp.sl.util.WeightVector;
import edu.illinois.cs.cogcomp.srl.learn.CrossValidationHelper.DatasetSplitter;
import edu.illinois.cs.cogcomp.srl.learn.CrossValidationHelper.PerformanceMeasureAverager;
......@@ -65,71 +63,71 @@ public class JLISCVHelper {
return splitIdStarts;
}
public static LearnerParameters cvSSVMSerial(
AbstractInferenceSolver[] inference,
SLProblem sp, Tester<SLProblem> evaluator,
int nFolds) throws Exception {
log.info("Cross validation for struct SVM");
int[] structSplits = getSplitLocations(nFolds, sp.size());
StructureProblemSplitter splitter = new StructureProblemSplitter(
structSplits);
CrossValidationHelper<SLProblem> cvHelper = new CrossValidationHelper<SLProblem>(
nFolds, inference, new RealMeasureAverager(), splitter,
new SSVMTrainer(), evaluator);
List<LearnerParameters> params = new ArrayList<LearnerParameters>();
for (int i = -8; i < 0; i++) {
params.add(LearnerParameters.getSSVMParams(Math.pow(2d, i)));
}
LearnerParameters learnerParameters = cvHelper.doCV(sp, params, false);
return learnerParameters;
}
public static LearnerParameters cvSSVM(
AbstractInferenceSolver[] inference,
SLProblem sp, Tester<SLProblem> evaluator,
int nThreads, int nFolds) throws Exception {
log.info("Cross validation for struct SVM");
int[] structSplits = getSplitLocations(nFolds, sp.size());
StructureProblemSplitter splitter = new StructureProblemSplitter(
structSplits);
CrossValidationHelper<SLProblem> cvHelper = new CrossValidationHelper<SLProblem>(
nFolds, inference, new RealMeasureAverager(), splitter,
new SSVMTrainer(), evaluator);
List<LearnerParameters> params = new ArrayList<LearnerParameters>();
for (int i = -8; i < 0; i++) {
params.add(LearnerParameters.getSSVMParams(Math.pow(2d, i)));
}
LearnerParameters learnerParameters = cvHelper.doCV(sp, params);
return learnerParameters;
}
// public static LearnerParameters cvSSVMSerial(
// AbstractInferenceSolver[] inference,
// SLProblem sp, Tester<SLProblem> evaluator,
// int nFolds) throws Exception {
//
// log.info("Cross validation for struct SVM");
// int[] structSplits = getSplitLocations(nFolds, sp.size());
//
// StructureProblemSplitter splitter = new StructureProblemSplitter(
// structSplits);
//
// CrossValidationHelper<SLProblem> cvHelper = new CrossValidationHelper<SLProblem>(
// nFolds, inference, new RealMeasureAverager(), splitter,
// new SSVMTrainer(), evaluator);
//
// List<LearnerParameters> params = new ArrayList<LearnerParameters>();
//
// for (int i = -8; i < 0; i++) {
// params.add(LearnerParameters.getSSVMParams(Math.pow(2d, i)));
// }
//
// LearnerParameters learnerParameters = cvHelper.doCV(sp, params, false);
//
// return learnerParameters;
// }
// public static LearnerParameters cvSSVM(
// AbstractInferenceSolver[] inference,
// SLProblem sp, Tester<SLProblem> evaluator,
// int nThreads, int nFolds) throws Exception {
//
// log.info("Cross validation for struct SVM");
// int[] structSplits = getSplitLocations(nFolds, sp.size());
//
// StructureProblemSplitter splitter = new StructureProblemSplitter(
// structSplits);
//
// CrossValidationHelper<SLProblem> cvHelper = new CrossValidationHelper<SLProblem>(
// nFolds, inference, new RealMeasureAverager(), splitter,
// new SSVMTrainer(), evaluator);
//
// List<LearnerParameters> params = new ArrayList<LearnerParameters>();
//
// for (int i = -8; i < 0; i++) {
// params.add(LearnerParameters.getSSVMParams(Math.pow(2d, i)));
// }
//
// LearnerParameters learnerParameters = cvHelper.doCV(sp, params);
//
// return learnerParameters;
// }
}
class SSVMTrainer implements Trainer<SLProblem> {
@Override
public WeightVector train(SLProblem dataset,
LearnerParameters params,
AbstractInferenceSolver[] inference) throws Exception {
return JLISLearner.trainStructSVM(inference, dataset,
params.getcStruct());
}
}
//class SSVMTrainer implements Trainer<SLProblem> {
//
// @Override
// public WeightVector train(SLProblem dataset,
// LearnerParameters params,
// AbstractInferenceSolver[] inference) throws Exception {
// return JLISLearner.trainStructSVM(inference, dataset,
// params.getcStruct());
// }
//
//}
abstract class SingleListDatasetSplitter<DatasetType> implements
DatasetSplitter<DatasetType> {
......
......@@ -18,7 +18,6 @@ import org.slf4j.LoggerFactory;
import java.io.FileNotFoundException;
import java.io.IOException;
public class JLISLearner {
private final static Logger log = LoggerFactory
......@@ -35,59 +34,56 @@ public class JLISLearner {
return WeightVectorUtils.load(modelName);
}
public static WeightVector trainStructSVM(
AbstractInferenceSolver[] inference,
SLProblem SLProblem, float c) throws Exception {
//L2LossSSVMParalleDCDSolver learner = new L2LossSSVMParalleDCDSolver();
return learner.
return learner.parallelTrainStructuredSVM(inference, SLProblem, params);
}
public static LearnerParameters cvStructSVMSRL(SLProblem problem,
AbstractInferenceSolver[] inference, int nFolds)
throws Exception {
Tester<SLProblem> evaluator = new Tester<SLProblem>() {
@Override
public PerformanceMeasure evaluate(SLProblem testSet,
WeightVector weight,
AbstractInferenceSolver inference)
throws Exception {
double p = JLISLearner.evaluateSRLLabel(inference, testSet,
weight);
return new JLISCVHelper.RealMeasure(p);
}
};
LearnerParameters bestParams = JLISCVHelper.cvSSVM(inference, problem,
evaluator, inference.length, nFolds);
return bestParams;
}
public static LearnerParameters cvStructSVM(SLProblem problem,
AbstractInferenceSolver[] inference, int nFolds,
Tester<SLProblem> evaluator) throws Exception {
LearnerParameters bestParams = JLISCVHelper.cvSSVM(inference, problem,
evaluator, inference.length, nFolds);
return bestParams;
}
public static double evaluateSRLLabel(
AbstractInferenceSolver inference,
// public static WeightVector trainStructSVM(
// AbstractInferenceSolver[] inference, SLProblem SLProblem, float c)
// throws Exception {
//
// // L2LossSSVMParalleDCDSolver learner = new
// // L2LossSSVMParalleDCDSolver();
//
// // return learner.parallelTrainStructuredSVM(inference, SLProblem,
// // params);
// }
// public static LearnerParameters cvStructSVMSRL(SLProblem problem,
// AbstractInferenceSolver[] inference, int nFolds) throws Exception {
// Tester<SLProblem> evaluator = new Tester<SLProblem>() {
//
// @Override
// public PerformanceMeasure evaluate(SLProblem testSet,
// WeightVector weight, AbstractInferenceSolver inference)
// throws Exception {
//
// double p = JLISLearner.evaluateSRLLabel(inference, testSet,
// weight);
//
// return new JLISCVHelper.RealMeasure(p);
//
// }
// };
//
// LearnerParameters bestParams = JLISCVHelper.cvSSVM(inference, problem,
// evaluator, inference.length, nFolds);
//
// return bestParams;
// }
// public static LearnerParameters cvStructSVM(SLProblem problem,
// AbstractInferenceSolver[] inference, int nFolds,
// Tester<SLProblem> evaluator) throws Exception {
// LearnerParameters bestParams = JLISCVHelper.cvSSVM(inference, problem,
// evaluator, inference.length, nFolds);
//
// return bestParams;
// }
public static double evaluateSRLLabel(AbstractInferenceSolver inference,
SLProblem testSet, WeightVector weights) throws Exception {
EvaluationRecord evalRecord = new EvaluationRecord();
for (int i = 0; i < testSet.input_list.size(); i++) {
IInstance x = testSet.input_list.get(i);
for (int i = 0; i < testSet.instanceList.size(); i++) {
IInstance x = testSet.instanceList.get(i);
SRLMulticlassLabel gold = (SRLMulticlassLabel) testSet.output_list
SRLMulticlassLabel gold = (SRLMulticlassLabel) testSet.goldStructureList
.get(i);
SRLMulticlassLabel bestStructure = (SRLMulticlassLabel) inference
......@@ -108,23 +104,23 @@ public class JLISLearner {
return evalRecord.getF1();
}
private static void initializeSolver(SLParameters params) {
// how precisely should the dual be solved
params.BINARY_DUAL_GAP = 0.1;
params.DUAL_GAP = 0.5;
params.TRAINMINI = true;
params.TRAINMINI_SIZE = 5000;
params.verbose_level = SLParameters.VLEVEL_MID;
params.MAX_SVM_ITER = 500;
// params.CLEAN_CACHE = false;
params.MAX_OUT_ITER = 25;
params.CALCULATE_REAL_OBJ = true;
}
// private static void initializeSolver(SLParameters params) {
//
// // how precisely should the dual be solved
// params.BINARY_DUAL_GAP = 0.1;
// params.DUAL_GAP = 0.5;
//
// params.TRAINMINI = true;
// params.TRAINMINI_SIZE = 5000;
//
// params.verbose_level = SLParameters.VLEVEL_MID;
//
// params.MAX_SVM_ITER = 500;
// // params.CLEAN_CACHE = false;
// params.MAX_OUT_ITER = 25;
//
// params.CALCULATE_REAL_OBJ = true;
//
// }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment