Commit 8dbadf36 authored by Shyam Upadhyay

can train sense; still training identifier

parent dd07d307
@@ -65,7 +65,7 @@ NombankHome = /shared/corpora/corporaWeb/treebanks/eng/nombank/
# The directory of the sentence and pre-extracted features database (~5G of space required)
# Not used during test/working with pre-trained models
# TODO Change this when done
CacheDirectory = /scratch/illinoisSRL/cache
CacheDirectory = /shared/bronte/upadhya3/illinoisSRL/cache
ModelsDirectory = models
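For readers unfamiliar with the config format: CacheDirectory and ModelsDirectory are plain key = value properties, so they can be read with java.util.Properties. A minimal sketch, assuming a hypothetical file name srl-config.properties (the project ships its own configuration loader, so this is illustration only):

import java.io.FileReader;
import java.util.Properties;

// Illustrative only: the class name and config file name are assumptions, not from this repository.
public class ReadSrlConfig {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        try (FileReader in = new FileReader("srl-config.properties")) {
            props.load(in);
        }
        // Roughly 5 GB of space is expected under CacheDirectory for the sentence
        // and pre-extracted feature database (not needed when using pre-trained models).
        System.out.println("CacheDirectory = " + props.getProperty("CacheDirectory"));
        System.out.println("ModelsDirectory = " + props.getProperty("ModelsDirectory", "models"));
    }
}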
......
@@ -109,13 +109,6 @@
<version>0.4.2</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.6.1</version>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.tartarus</groupId>
<artifactId>snowball</artifactId>
......
@@ -15,11 +15,13 @@ import edu.illinois.cs.cogcomp.core.utilities.commands.InteractiveShell;
import edu.illinois.cs.cogcomp.infer.ilp.ILPSolverFactory;
import edu.illinois.cs.cogcomp.nlp.corpusreaders.NombankReader;
import edu.illinois.cs.cogcomp.nlp.corpusreaders.PropbankReader;
import edu.illinois.cs.cogcomp.sl.core.SLModel;
import edu.illinois.cs.cogcomp.sl.core.SLParameters;
import edu.illinois.cs.cogcomp.sl.core.SLProblem;
import edu.illinois.cs.cogcomp.sl.learner.Learner;
import edu.illinois.cs.cogcomp.sl.learner.LearnerFactory;
import edu.illinois.cs.cogcomp.sl.learner.l2_loss_svm.L2LossSSVMLearner;
import edu.illinois.cs.cogcomp.sl.util.IFeatureVector;
import edu.illinois.cs.cogcomp.sl.util.WeightVector;
import edu.illinois.cs.cogcomp.srl.caches.FeatureVectorCacheFile;
import edu.illinois.cs.cogcomp.srl.caches.SentenceDBHandler;
@@ -34,6 +36,8 @@ import edu.illinois.cs.cogcomp.srl.experiment.TextPreProcessor;
import edu.illinois.cs.cogcomp.srl.inference.SRLILPInference;
import edu.illinois.cs.cogcomp.srl.inference.SRLMulticlassInference;
import edu.illinois.cs.cogcomp.srl.jlis.SRLFeatureExtractor;
import edu.illinois.cs.cogcomp.srl.jlis.SRLMulticlassInstance;
import edu.illinois.cs.cogcomp.srl.jlis.SRLMulticlassLabel;
import edu.illinois.cs.cogcomp.srl.learn.IdentifierThresholdTuner;
import edu.illinois.cs.cogcomp.srl.learn.JLISLearner;
import edu.illinois.cs.cogcomp.srl.nom.NomSRLManager;
@@ -105,12 +109,12 @@ public class Main {
preExtract(srlType, "Identifier");
train(srlType, "Identifier");
tuneIdentifier(srlType);
// tuneIdentifier(srlType);
//
preExtract(srlType, "Classifier");
train(srlType, "Classifier");
// Step 3: Evaluate
//
// // Step 3: Evaluate
evaluate(srlType);
}
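Read without the commented-out lines, the pipeline this hunk leaves in place is the following sequence (method names exactly as in Main; identifier threshold tuning is disabled here):

preExtract(srlType, "Identifier");
train(srlType, "Identifier");
// tuneIdentifier(srlType);    // disabled in this commit

preExtract(srlType, "Classifier");
train(srlType, "Classifier");

// Step 3: Evaluate
evaluate(srlType);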
@@ -271,6 +275,7 @@ public class Main {
String allDataCacheFile = properties.getFeatureCacheFile(srlType,
modelToExtract, featureSet, defaultParser, dataset);
System.out.println("reading feature cache from " + allDataCacheFile);
FeatureVectorCacheFile featureCache = preExtract(numConsumers, manager,
modelToExtract, dataset, allDataCacheFile, false);
@@ -300,9 +305,11 @@
Models modelToExtract, FeatureVectorCacheFile featureCache,
String cacheFile2) throws Exception {
if (IOUtils.exists(cacheFile2)) {
log.warn("Old pruned cache file found. Deleting...");
IOUtils.rm(cacheFile2);
log.info("Done");
log.warn("Old pruned cache file found. Not doing anything...");
return;
// log.warn("Old pruned cache file found. Deleting...");
// IOUtils.rm(cacheFile2);
// log.info("Done");
}
log.info("Pruning features. Saving pruned features to {}", cacheFile2);
@@ -320,9 +327,24 @@
int numConsumers, SRLManager manager, Models modelToExtract, Dataset dataset,
String cacheFile, boolean lockLexicon) throws Exception {
if (IOUtils.exists(cacheFile)) {
log.warn("Old cache file found. Deleting...");
IOUtils.rm(cacheFile);
log.info("Done");
// log.warn("Old cache file found. Deleting...");
log.warn("Old cache file found. Returning it...");
// IOUtils.rm(cacheFile);
// log.info("Done");
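// Reuse the existing cache: replay every cached example so that the lexicon's
// per-feature counts are rebuilt without re-running feature extraction.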
FeatureVectorCacheFile vectorCacheFile = new FeatureVectorCacheFile(cacheFile, modelToExtract, manager);
vectorCacheFile.openReader();
while (vectorCacheFile.hasNext()) {
Pair<SRLMulticlassInstance, SRLMulticlassLabel> pair=vectorCacheFile.next();
IFeatureVector cachedFeatureVector = pair.getFirst().getCachedFeatureVector(modelToExtract);
for (int i : cachedFeatureVector.getIndices()) {
// System.out.printf(i+" ");
manager.getModelInfo(modelToExtract).getLexicon().countFeature(i);
}
}
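// Rewind (close and reopen the reader) so the returned cache reads from the beginning.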
vectorCacheFile.close();
vectorCacheFile.openReader();
return vectorCacheFile;
}
FeatureVectorCacheFile featureCache = new FeatureVectorCacheFile(cacheFile, modelToExtract, manager);
@@ -357,6 +379,7 @@
String featureSet = "" + modelInfo.featureManifest.getIncludedFeatures().hashCode();
String cacheFile = properties.getPrunedFeatureCacheFile(srlType, model, featureSet, defaultParser);
System.out.println("In train feat cahce is "+cacheFile);
// AbstractInferenceSolver[] inference = new AbstractInferenceSolver[numThreads];
// for (int i = 0; i < inference.length; i++)
@@ -370,26 +393,30 @@
log.info("Skipping cross-validation for Classifier. c = {}", c);
}
else {
cache = new FeatureVectorCacheFile(cacheFile, model, manager);
SLProblem cvProblem = cache.getStructuredProblem(20000);
cache.close();
// cache = new FeatureVectorCacheFile(cacheFile, model, manager);
// SLProblem cvProblem = cache.getStructuredProblem(20000);
// cache.close();
// LearnerParameters params = JLISLearner.cvStructSVMSRL(cvProblem, inference, 5);
// c = params.getcStruct();
log.info("c = {} for {} after cv", c, srlType + " " + model);
// log.info("c = {} for {} after cv", c, srlType + " " + model);
}
cache = new FeatureVectorCacheFile(cacheFile, model, manager);
SLModel slmodel = new SLModel();
SLProblem problem = cache.getStructuredProblem();
cache.close();
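// Configure illinois-sl: an L2-loss SSVM trained with the parallel dual coordinate descent solver across numThreads threads.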
SLParameters params = new SLParameters();
params.C_FOR_STRUCTURE = (float) c;
// initializeSolver(params);
params.L2_LOSS_SSVM_SOLVER_TYPE = L2LossSSVMLearner.SolverType.ParallelDCDSolver;
params.NUMBER_OF_THREADS = numThreads;
Learner learner = LearnerFactory.getLearner(new SRLMulticlassInference(manager, model), new SRLFeatureExtractor(), params);
SRLMulticlassInference infSolver = new SRLMulticlassInference(manager, model);
Learner learner = LearnerFactory.getLearner(infSolver, new SRLFeatureExtractor(), params);
WeightVector w = learner.train(problem);
JLISLearner.saveWeightVector(w, manager.getModelFileName(model));
JLISLearner.evaluateSRLLabel(infSolver, problem, w);
}
private static void tuneIdentifier(String srlType_) throws Exception {
......
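Stripped of the commented-out cross-validation code, the train() hunk above reduces to the standard illinois-sl flow. A condensed view, using only identifiers that already appear in the hunk:

// Load all cached (instance, label) examples into a structured problem.
FeatureVectorCacheFile cache = new FeatureVectorCacheFile(cacheFile, model, manager);
SLProblem problem = cache.getStructuredProblem();
cache.close();

// L2-loss SSVM with the parallel dual coordinate descent solver.
SLParameters params = new SLParameters();
params.C_FOR_STRUCTURE = (float) c;
params.L2_LOSS_SSVM_SOLVER_TYPE = L2LossSSVMLearner.SolverType.ParallelDCDSolver;
params.NUMBER_OF_THREADS = numThreads;

// Train, save the weight vector, and report accuracy on the training problem.
SRLMulticlassInference infSolver = new SRLMulticlassInference(manager, model);
Learner learner = LearnerFactory.getLearner(infSolver, new SRLFeatureExtractor(), params);
WeightVector w = learner.train(problem);
JLISLearner.saveWeightVector(w, manager.getModelFileName(model));
JLISLearner.evaluateSRLLabel(infSolver, problem, w);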
@@ -94,7 +94,7 @@ public class FeatureVectorCacheFile implements Closeable,
}
}
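// Opens the underlying gzip stream for reading; made public so callers such as Main can rewind and re-read an existing cache.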
protected void openReader() throws IOException {
public void openReader() throws IOException {
GZIPInputStream zipin = new GZIPInputStream(new FileInputStream(file));
reader = new BufferedReader(new InputStreamReader(zipin));
}
......
@@ -10,6 +10,8 @@ import edu.illinois.cs.cogcomp.srl.core.Models;
import edu.illinois.cs.cogcomp.srl.core.SRLManager;
import edu.illinois.cs.cogcomp.srl.jlis.SRLMulticlassInstance;
import edu.illinois.cs.cogcomp.srl.jlis.SRLMulticlassLabel;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import java.util.ArrayList;
import java.util.List;
@@ -31,6 +33,7 @@ public class PruningPreExtractor extends
protected final List<PreExtractRecord> buffer = new ArrayList<PreExtractRecord>();
private AtomicInteger counter = new AtomicInteger();
private Logger log = org.slf4j.LoggerFactory.getLogger(PruningPreExtractor.class);
public PruningPreExtractor(SRLManager manager, Models modelToExtract,
FeatureVectorCacheFile examples, FeatureVectorCacheFile cache,
@@ -115,7 +118,7 @@
cache.put(r.lemma, r.label, r.features);
}
log.info("Saving pruned feature cache done!");
cache.close();
}
}
@@ -109,11 +109,11 @@ return best;
public float getLoss(IInstance ins, IStructure gold, IStructure pred) {
SRLMulticlassLabel yGold = (SRLMulticlassLabel) gold;
SRLMulticlassLabel ypred= (SRLMulticlassLabel) pred;
double l=0;
float l=0;
if (yGold.getLabel() != ypred.getLabel())
l++;
return 0;
return l;
}
@Override
......
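The corrected getLoss now returns a standard 0-1 loss instead of always returning 0. A minimal standalone sketch of the same logic (the method name and signature here are illustrative, not from the repository):

// 0-1 loss: 1 when the predicted label differs from the gold label, 0 otherwise.
static float zeroOneLoss(int goldLabel, int predictedLabel) {
    return goldLabel != predictedLabel ? 1f : 0f;
}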