/**
 * Builds the classifier from the ARFF training data in the model directory.
 *
 * @param config the common configuration
 * @throws ConfigurationException if the training data cannot be read or the classifier cannot be built
 */
private void initializeModel(CommonConfig config) throws ConfigurationException {
  // Train the classifier
  logger.info("Training the classifier...");
  File arffFile = new File(modelDir + "/" + this.getClass().getSimpleName() + ".arff");
  classifier = new NaiveBayes();
  try {
    Instances data = DataSource.read(arffFile.getAbsolutePath());
    data.setClassIndex(data.numAttributes() - 1);
    classifier.buildClassifier(data);
  } catch (Exception e) {
    throw new ConfigurationException(e);
  }
}
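A minimal usage sketch for the classifier built above, assuming a test ARFF with the same header as the training data; the method name and the file name "test.arff" are placeholders, not part of the original class.

// Sketch only: apply the classifier built in initializeModel() to unseen data.
// "test.arff" is a placeholder; the file must share the training header.
private void classifyTestFile() throws Exception {
  Instances test = DataSource.read(modelDir + "/test.arff");
  test.setClassIndex(test.numAttributes() - 1);
  for (int i = 0; i < test.numInstances(); i++) {
    double predicted = classifier.classifyInstance(test.instance(i));       // index of the predicted class
    double[] dist = classifier.distributionForInstance(test.instance(i));   // class probability distribution
    logger.info(test.classAttribute().value((int) predicted) + " p=" + dist[(int) predicted]);
  }
}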
/** * <p>To get the distribution of inTrace and outTrace instance in given dataset in <b>path</b>.</p> * @param ins Instances of each project * @throws Exception */ public static void getDist(String path) throws Exception{ Instances ins = DataSource.read(path); int numAttr = ins.numAttributes(); ins.setClassIndex(numAttr-1); int numIns = ins.numInstances(); int intrace = 0; int outtrace = 0; for(int i=0; i<numIns; i++){ if(ins.get(i).stringValue(ins.attribute(ins.classIndex())).equals("InTrace")){ intrace++; }else{ outtrace++; } } System.out.printf("[ %-30s ] inTrace:%4d, outTrace:%4d.\n", path, intrace, outtrace); }
/*** * <p>To get standard attribute list from <b>files/empty.arff</b></p> * @return lsAttr */ public static ArrayList<Attribute> getStandAttrs(){ ArrayList<Attribute> lsAttr = new ArrayList<>(); try { Instances empty = DataSource.read("files/empty.arff"); int numAttr = empty.numAttributes(); // empty.setClassIndex(numAttr - 1); for(int i=0; i<numAttr; i++){ lsAttr.add(empty.attribute(i)); } } catch (Exception e) { System.out.println("reading empty arff error!"); e.printStackTrace(); } return lsAttr; }
/*** * <p>To get 10-fold cross validation in one single arff in <b>path</b></p> * <p>Use C4.5 and <b>SMOTE</b> to classify the dataset.</p> * @param path dataset path * @throws Exception */ public static void getEvalResultbySMOTE(String path, int index) throws Exception{ Instances ins = DataSource.read(path); int numAttr = ins.numAttributes(); ins.setClassIndex(numAttr - 1); SMOTE smote = new SMOTE(); smote.setInputFormat(ins); /** classifiers setting*/ J48 j48 = new J48(); // j48.setConfidenceFactor(0.4f); j48.buildClassifier(ins); FilteredClassifier fc = new FilteredClassifier(); fc.setClassifier(j48); fc.setFilter(smote); Evaluation eval = new Evaluation(ins); eval.crossValidateModel(fc, ins, 10, new Random(1)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(0), eval.recall(0), eval.fMeasure(0)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(1), eval.recall(1), eval.fMeasure(1)); // System.out.printf(" %4.3f \n\n", (1-eval.errorRate())); results[index][0] = eval.precision(0); results[index][1] = eval.recall(0); results[index][2] = eval.fMeasure(0); results[index][3] = eval.precision(1); results[index][4] = eval.recall(1); results[index][5] = eval.fMeasure(1); results[index][6] = 1-eval.errorRate(); }
/**
 * For testing only.
 *
 * @param args expects a dataset as first parameter
 * @throws Exception if something goes wrong
 */
public static void main(String[] args) throws Exception {
  if (args.length != 1) {
    System.err.println("No dataset supplied!");
    System.exit(1);
  }

  // load dataset
  Instances data = DataSource.read(args[0]);

  // turn Instances into JSON object and output it
  JSONNode json = toJSON(data);
  StringBuffer buffer = new StringBuffer();
  json.toString(buffer);
  System.out.println(buffer.toString());

  // turn JSON object back into Instances and output it
  Instances inst = toInstances(json);
  System.out.println(inst);
}
/**
 * Test to see if the script save feature works.
 * @throws Exception
 */
@Test
public void testScriptSave() throws Exception {
  System.out.println("testScriptSave()");
  PyScriptClassifier ps = (PyScriptClassifier) getClassifier();
  ps.setDebug(true);
  ps.setPythonFile(new File("scripts/zeror.py"));
  ps.setArguments("");
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances train = ds.getDataSet();
  train.setClassIndex(train.numAttributes() - 1);
  ps.setSaveScript(true);
  ps.buildClassifier(train);
  // we saved the script so it doesn't matter where it is now
  ps.setPythonFile(new File("bad-file.py"));
  ps.distributionsForInstances(train);
}
/**
 * See if the standardise filter example behaves in the
 * exact same way as the one in WEKA.
 * @throws Exception
 */
@Test
public void testStandardise() throws Exception {
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);

  PyScriptFilter filter = new PyScriptFilter();
  filter.setPythonFile(new File("scripts/standardise.py"));
  filter.setInputFormat(data);
  Instances pyscriptData = Filter.useFilter(data, filter);

  Standardize filter2 = new Standardize();
  filter2.setInputFormat(data);
  Instances defaultStdData = Filter.useFilter(data, filter2);

  // test instances
  for (int x = 0; x < data.numInstances(); x++) {
    assertTrue(pyscriptData.get(x).toString().equals(defaultStdData.get(x).toString()));
  }
}
/**
 * Make sure that the class attribute is also modified when
 * ignoreClass is set (i.e., the class attribute is treated
 * like any other attribute).
 */
@Test
public void testIgnoreClass() throws Exception {
  DataSource ds = new DataSource("datasets/diabetes_numeric.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);
  double[] classAttributeValues = data.attributeToDoubleArray(data.classIndex());

  PyScriptFilter filter = new PyScriptFilter();
  filter.setPythonFile(new File("scripts/standardise.py"));
  filter.setIgnoreClass(true);
  filter.setInputFormat(data);
  Instances filteredData = Filter.useFilter(data, filter);

  assert (classAttributeValues[0] != filteredData.attributeToDoubleArray(filteredData.classIndex())[0]);
}
/**
 * Test that the standardise script on Iris works the same
 * whether or not the class attribute is ignored (since
 * the class attribute in this case is nominal).
 * @throws Exception
 */
@Test
public void testIgnoreClassOnIris() throws Exception {
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);

  PyScriptFilter f1 = new PyScriptFilter();
  f1.setPythonFile(new File("scripts/standardise.py"));
  //f1.setIgnoreClass(false);
  f1.setInputFormat(data);
  Instances f1Data = Filter.useFilter(data, f1);

  PyScriptFilter f2 = new PyScriptFilter();
  f2.setPythonFile(new File("scripts/standardise.py"));
  f2.setIgnoreClass(true);
  f2.setInputFormat(data);
  Instances f2Data = Filter.useFilter(data, f2);

  for (int x = 0; x < f1Data.numInstances(); x++) {
    assertArrayEquals(f1Data.get(x).toDoubleArray(), f2Data.get(x).toDoubleArray(), 1e-6);
  }
}
/**
 * Test filtered classifier using PyScriptFilter with
 * PyScriptClassifier. Just seeing if no exceptions
 * are thrown here.
 */
@Test
public void testFilteredClassifier() throws Exception {
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);

  FilteredClassifier fs = new FilteredClassifier();

  PyScriptClassifier pyScriptClassifier = new PyScriptClassifier();
  pyScriptClassifier.setPythonFile(new File("scripts/scikit-rf.py"));
  pyScriptClassifier.setArguments("num_trees=10;");

  PyScriptFilter filter = new PyScriptFilter();
  filter.setPythonFile(new File("scripts/standardise.py"));

  fs.setClassifier(pyScriptClassifier);
  fs.setFilter(filter);
  fs.buildClassifier(data);
  fs.distributionsForInstances(data);
}
/**
 * Prints the classifications to the buffer.
 *
 * @param classifier the classifier to use for printing the classifications
 * @param testset the data source to obtain the test instances from
 * @throws Exception if check fails or error occurs during printing of
 *           classifications
 */
public void printClassifications(Classifier classifier, DataSource testset) throws Exception {
  int i;
  Instances test;
  Instance inst;

  i = 0;
  testset.reset();

  if (classifier instanceof BatchPredictor) {
    test = testset.getDataSet(m_Header.classIndex());
    double[][] predictions = ((BatchPredictor) classifier).distributionsForInstances(test);
    for (i = 0; i < test.numInstances(); i++) {
      printClassification(predictions[i], test.instance(i), i);
    }
  } else {
    test = testset.getStructure(m_Header.classIndex());
    while (testset.hasMoreElements(test)) {
      inst = testset.nextElement(test);
      doPrintClassification(classifier, inst, i);
      i++;
    }
  }
}
/**
 * Creates and trains an SVM (SMO) classifier.
 *
 * @param dataset the ARFF file used to train the SVM classifier
 * @param subsequenceLength the subsequence length of the string kernel
 * @param lambda the lambda (decay) parameter of the string kernel
 */
protected SMO createLearnedSVM(String dataset, int subsequenceLength, double lambda) throws Exception {
  SMO smo = new SMO();
  data = DataSource.read(dataset);
  data.setClassIndex(data.numAttributes() - 1);
  StringKernel kernel = new StringKernel(data, 250007, subsequenceLength, lambda, debug);
  smo.setKernel(kernel);
  smo.buildClassifier(data);
  return smo;
}
@Override
public void train(List<MLExample> pTrainExamples) throws Exception {
  ConfigurationUtil.TrainingMode = true;
  setPaths();

  // This part was added because the session was too slow
  List<Integer> train_example_ids = new ArrayList<Integer>();
  for (MLExample example : pTrainExamples) {
    train_example_ids.add(example.getExampleId());
  }
  WekaFormatConvertor.writeToFile(train_example_ids, trainFile, taskName, new String[]{"1", "2"});

  DataSource source = new DataSource(trainFile);
  Instances data = source.getDataSet();
  // setting class attribute if the data format does not provide this information
  if (data.classIndex() == -1)
    data.setClassIndex(data.numAttributes() - 1);

  if (options != null)
    wekaAlgorithm.setOptions(options); // set the options
  wekaAlgorithm.buildClassifier(data); // build classifier

  // serialize model
  SerializationHelper.write(modelFile, wekaAlgorithm);
}
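A minimal counterpart sketch for prediction, assuming the same `modelFile` field; the method name and the `testFile` parameter are illustrative only, since the original prediction interface is not shown. The serialized classifier is read back with Weka's SerializationHelper and applied to a test set.

// Sketch only: load the model serialized in train() and classify a test file.
// "predict" and "testFile" are placeholder names, not from the original interface.
public double[] predict(String testFile) throws Exception {
  Classifier model = (Classifier) SerializationHelper.read(modelFile);
  Instances test = DataSource.read(testFile);
  if (test.classIndex() == -1)
    test.setClassIndex(test.numAttributes() - 1);

  double[] predictions = new double[test.numInstances()];
  for (int i = 0; i < test.numInstances(); i++) {
    predictions[i] = model.classifyInstance(test.instance(i));
  }
  return predictions;
}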
/**
 * Loads the Q010 dataset and runs the meta-classifier, filter, and low-level examples.
 *
 * @param args the commandline arguments (ignored; the dataset path is hard-coded)
 * @throws Exception if something goes wrong
 */
public static void main(String[] args) throws Exception {
  // load data
  System.out.println("\n0. Loading data");
  DataSource source = new DataSource(System.getProperty("user.dir") + "/data/Arffs/Q010.arff");
  Instances data = source.getDataSet();
  data.deleteAttributeAt(0);
  data.setClassIndex(data.numAttributes() - 1);

  // 1. meta-classifier
  useClassifier(data);

  // 2. filter
  useFilter(data);

  // 3. low-level
  useLowLevel(data);
}
/**
 * Loads test data, if required.
 *
 * @param data the current training data
 * @throws Exception if test sets are not compatible with training data
 */
protected void loadTestData(Instances data) throws Exception {
  String msg;

  m_InitialSpaceTestInst = null;
  if (m_InitialSpaceTestSet.exists() && !m_InitialSpaceTestSet.isDirectory()) {
    m_InitialSpaceTestInst = DataSource.read(m_InitialSpaceTestSet.getAbsolutePath());
    m_InitialSpaceTestInst.setClassIndex(data.classIndex());
    msg = data.equalHeadersMsg(m_InitialSpaceTestInst);
    if (msg != null)
      throw new IllegalArgumentException("Test set for initial space not compatible with training data:\n" + msg);
    m_InitialSpaceTestInst.deleteWithMissingClass();
    log("Using test set for initial space: " + m_InitialSpaceTestSet);
  }

  m_SubsequentSpaceTestInst = null;
  if (m_SubsequentSpaceTestSet.exists() && !m_SubsequentSpaceTestSet.isDirectory()) {
    m_SubsequentSpaceTestInst = DataSource.read(m_SubsequentSpaceTestSet.getAbsolutePath());
    m_SubsequentSpaceTestInst.setClassIndex(data.classIndex());
    msg = data.equalHeadersMsg(m_SubsequentSpaceTestInst);
    if (msg != null)
      throw new IllegalArgumentException("Test set for subsequent sub-spaces not compatible with training data:\n" + msg);
    m_SubsequentSpaceTestInst.deleteWithMissingClass();
    log("Using test set for subsequent sub-spaces: " + m_SubsequentSpaceTestSet);
  }
}
/**
 * Loads test data, if required.
 *
 * @param data the current training data
 * @throws Exception if test sets are not compatible with training data
 */
protected void loadTestData(Instances data) throws Exception {
  String msg;

  m_SearchSpaceTestInst = null;
  if (m_SearchSpaceTestSet.exists() && !m_SearchSpaceTestSet.isDirectory()) {
    m_SearchSpaceTestInst = DataSource.read(m_SearchSpaceTestSet.getAbsolutePath());
    m_SearchSpaceTestInst.setClassIndex(data.classIndex());
    msg = data.equalHeadersMsg(m_SearchSpaceTestInst);
    if (msg != null) {
      throw new IllegalArgumentException("Test set for search space not compatible with training data:\n" + msg);
    }
    m_SearchSpaceTestInst.deleteWithMissingClass();
    log("Using test set for search space: " + m_SearchSpaceTestSet);
  }
}
public double testForOdds(String odds) {
  try {
    String header = "SIGN, B365H, B365D, B365A, BWH, BWD, BWA, IWH, IWD, IWA, LBH, LBD, LBA, PSH, PSD, PSA, WHH, WHD, WHA, SJH, SJD, SJA, VCH, VCD, VCA";
    PrintWriter writer = new PrintWriter("/main/resources/tsi/" + folder + "/temp.csv", "UTF-8");
    writer.println(header);
    writer.println(odds);
    writer.close();

    DataSource tempSource = new DataSource(Neural.class.getResourceAsStream("/main/resources/tsi/" + folder + "/temp.csv"));
    Instances tempSet = tempSource.getDataSet(0);
    Instance currInstance = tempSet.instance(0);
    double[] distForInstance = percep.distributionForInstance(currInstance);
    return distForInstance[0];
  } catch (Exception e) {
    e.printStackTrace();
  }
  return Double.NaN;
}
/*** * <p>To Merge the datasets in path array and save the total dataset in dirpath. * </p> * @param path String array of arff file * @throws Exception */ public static void getIns(String[] path, String dirpath) throws Exception{ /** Create a empty dataset total*/ Instances total = new Instances("total3500", getStandAttrs(), 1); total.setClassIndex(total.numAttributes() - 1); int len = path.length; Instances[] temp = new Instances[len]; for(int i=0; i<path.length; i++){ temp[i] = DataSource.read(path[i]); temp[i].setClassIndex(temp[i].numAttributes() - 1); total.addAll(temp[i]); System.out.println("adding " + path[i] + " " + temp[i].numInstances()); // System.out.println("data" + total.numInstances() + "\n"); } String totalName = dirpath+"total3500" + String.valueOf(System.currentTimeMillis()) + ".arff"; DataSink.write(totalName, total); System.out.println("Writing the data into [" + totalName + "] successfully.\n"); }
/*** * <p>To get 10-fold cross validation in one single arff in <b>path</b></p> * <p>Only use C4.5 to classify the dataset.</p> * @param path dataset path * @throws Exception */ public static void getEvalResultbyNo(String path, int index) throws Exception{ Instances ins = DataSource.read(path); int numAttr = ins.numAttributes(); ins.setClassIndex(numAttr - 1); /** classifiers setting*/ J48 j48 = new J48(); // j48.setConfidenceFactor(0.4f); j48.buildClassifier(ins); Evaluation eval = new Evaluation(ins); eval.crossValidateModel(j48, ins, 10, new Random(1)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(0), eval.recall(0), eval.fMeasure(0)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(1), eval.recall(1), eval.fMeasure(1)); // System.out.printf(" %4.3f \n\n", (1-eval.errorRate())); results[index][0] = eval.precision(0); results[index][1] = eval.recall(0); results[index][2] = eval.fMeasure(0); results[index][3] = eval.precision(1); results[index][4] = eval.recall(1); results[index][5] = eval.fMeasure(1); results[index][6] = 1-eval.errorRate(); }
/*** * <p>To get 10-fold cross validation in one single arff in <b>path</b></p> * <p>Use C4.5 and <b>Resampling</b> to classify the dataset.</p> * @param path dataset path * @throws Exception */ public static void getEvalResultbyResampling(String path, int index) throws Exception{ Instances ins = DataSource.read(path); int numAttr = ins.numAttributes(); ins.setClassIndex(numAttr - 1); Resample resample = new Resample(); resample.setInputFormat(ins); /** classifiers setting*/ J48 j48 = new J48(); // j48.setConfidenceFactor(0.4f); j48.buildClassifier(ins); FilteredClassifier fc = new FilteredClassifier(); fc.setClassifier(j48); fc.setFilter(resample); Evaluation eval = new Evaluation(ins); eval.crossValidateModel(fc, ins, 10, new Random(1)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(0), eval.recall(0), eval.fMeasure(0)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(1), eval.recall(1), eval.fMeasure(1)); // System.out.printf(" %4.3f \n\n", (1-eval.errorRate())); results[index][0] = eval.precision(0); results[index][1] = eval.recall(0); results[index][2] = eval.fMeasure(0); results[index][3] = eval.precision(1); results[index][4] = eval.recall(1); results[index][5] = eval.fMeasure(1); results[index][6] = 1-eval.errorRate(); }
/*** * <p>To get 10-fold cross validation in one single arff in <b>path</b></p> * <p>Use C4.5 and <b>Cost-sensitive learning</b> to classify the dataset.</p> * @param path dataset path * @throws Exception */ public static void getEvalResultbyCost(String path, int index) throws Exception{ Instances ins = DataSource.read(path); int numAttr = ins.numAttributes(); ins.setClassIndex(numAttr - 1); /**Classifier setting*/ J48 j48 = new J48(); // j48.setConfidenceFactor(0.4f); j48.buildClassifier(ins); CostSensitiveClassifier csc = new CostSensitiveClassifier(); csc.setClassifier(j48); csc.setCostMatrix(new CostMatrix(new BufferedReader(new FileReader("files/costm")))); Evaluation eval = new Evaluation(ins); eval.crossValidateModel(csc, ins, 10, new Random(1)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(0), eval.recall(0), eval.fMeasure(0)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(1), eval.recall(1), eval.fMeasure(1)); // System.out.printf(" %4.3f \n\n", (1-eval.errorRate())); results[index][0] = eval.precision(0); results[index][1] = eval.recall(0); results[index][2] = eval.fMeasure(0); results[index][3] = eval.precision(1); results[index][4] = eval.recall(1); results[index][5] = eval.fMeasure(1); results[index][6] = 1-eval.errorRate(); }
/*** * <p>To get 10-fold cross validation in one single arff in <b>path</b></p> * <p>Use C4.5 and <b>SMOTE</b> to classify the dataset.</p> * @param path dataset path * @throws Exception */ public static void getEvalResultbyDefault(String path, int index) throws Exception{ Instances ins = DataSource.read(path); int numAttr = ins.numAttributes(); ins.setClassIndex(numAttr - 1); SMOTE smote = new SMOTE(); smote.setInputFormat(ins); /** classifiers setting*/ J48 j48 = new J48(); // j48.setConfidenceFactor(0.4f); j48.buildClassifier(ins); FilteredClassifier fc = new FilteredClassifier(); fc.setClassifier(j48); fc.setFilter(smote); Evaluation eval = new Evaluation(ins); eval.crossValidateModel(fc, ins, 10, new Random(1)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(0), eval.recall(0), eval.fMeasure(0)); // System.out.printf(" %4.3f %4.3f %4.3f", eval.precision(1), eval.recall(1), eval.fMeasure(1)); // System.out.printf(" %4.3f \n\n", (1-eval.errorRate())); results[index][0] = eval.precision(0); results[index][1] = eval.recall(0); results[index][2] = eval.fMeasure(0); results[index][3] = eval.precision(1); results[index][4] = eval.recall(1); results[index][5] = eval.fMeasure(1); results[index][6] = 1-eval.errorRate(); }
public void loadInstanceFromPath(String path) throws Exception {
  DataSource source = new DataSource(path);
  Instances data = source.getDataSet();
  if (data.classIndex() == -1) {
    data.setClassIndex(data.numAttributes() - 1);
  }
  this.value = data;
  MiniMLLogger.INSTANCE.info("Dataset loaded with these attributes");
  for (int i = 0; i < data.numAttributes(); i++) {
    MiniMLLogger.INSTANCE.info(data.attribute(i));
  }
}
/**
 * Gets Instances from an ARFF file.
 *
 * @param fileLocation path to the ARFF file
 * @return the Instances read from the given ARFF file, or null if reading fails
 */
private Instances getInstancesFromARFF(String fileLocation) {
  Instances instances = null;
  try {
    DataSource dataSource = new DataSource(fileLocation);
    instances = dataSource.getDataSet();
  } catch (Exception ex) {
    System.out.println("Can't find ARFF file at given location: " + fileLocation);
  }
  return instances;
}
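getInstancesFromARFF leaves the class index unset and may return null. A minimal caller sketch follows; the path "data/train.arff" and the last-attribute-as-class assumption are placeholders for illustration.

// Sketch only: guard the null return and set the class index before use.
Instances data = getInstancesFromARFF("data/train.arff");   // path is a placeholder
if (data == null) {
  throw new IllegalStateException("Training data could not be loaded");
}
if (data.classIndex() == -1) {
  data.setClassIndex(data.numAttributes() - 1);              // assume the class is the last attribute
}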
/**
 * Perform attribute selection with a particular evaluator and a set of
 * options specifying search method and input file etc.
 *
 * @param ASEvaluator an evaluator object
 * @param options an array of options, not only for the evaluator but also the
 *          search method (if any) and an input data file
 * @return the results of attribute selection as a String
 * @exception Exception if no training file is set
 */
public static String SelectAttributes(ASEvaluation ASEvaluator, String[] options) throws Exception {
  String trainFileName, searchName;
  Instances train = null;
  ASSearch searchMethod = null;
  String[] optionsTmp = options.clone();
  boolean helpRequested = false;

  try {
    // get basic options (options the same for all attribute selectors)
    trainFileName = Utils.getOption('i', options);
    helpRequested = Utils.getFlag('h', optionsTmp);

    if (helpRequested || (trainFileName.length() == 0)) {
      searchName = Utils.getOption('s', optionsTmp);
      if (searchName.length() != 0) {
        String[] searchOptions = Utils.splitOptions(searchName);
        searchMethod = (ASSearch) Class.forName(searchOptions[0]).newInstance();
      }

      if (helpRequested) {
        throw new Exception("Help requested.");
      } else {
        throw new Exception("No training file given.");
      }
    }
  } catch (Exception e) {
    throw new Exception('\n' + e.getMessage() + makeOptionString(ASEvaluator, searchMethod));
  }

  DataSource source = new DataSource(trainFileName);
  train = source.getDataSet();

  return SelectAttributes(ASEvaluator, options, train);
}
/**
 * Prints the classifications to the buffer.
 *
 * @param classifier the classifier to use for printing the classifications
 * @param testset the data source to obtain the test instances from
 * @throws Exception if check fails or error occurs during printing of
 *           classifications
 */
public void printClassifications(Classifier classifier, DataSource testset) throws Exception {
  int i;
  Instances test;
  Instance inst;

  i = 0;
  testset.reset();

  if (classifier instanceof BatchPredictor
    && ((BatchPredictor) classifier).implementsMoreEfficientBatchPrediction()) {
    test = testset.getDataSet(m_Header.classIndex());
    double[][] predictions = ((BatchPredictor) classifier).distributionsForInstances(test);
    for (i = 0; i < test.numInstances(); i++) {
      printClassification(predictions[i], test.instance(i), i);
    }
  } else {
    test = testset.getStructure(m_Header.classIndex());
    while (testset.hasMoreElements(test)) {
      inst = testset.nextElement(test);
      doPrintClassification(classifier, inst, i);
      i++;
    }
  }
}
/**
 * @param args the command line arguments
 * @throws java.lang.Exception
 */
public static void main(String[] args) throws Exception {
  DataSource source = new DataSource("src/files/letter.arff");
  int folds = 10;
  int runs = 30;
  Classifier cls = new NaiveBayes();

  Instances data = source.getDataSet();
  data.setClassIndex(16);

  System.out.println("#seed \t correctly classified instances \t percentage correct\n");
  for (int i = 1; i <= runs; i++) {
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(cls, data, folds, new Random(i));
    System.out.println("#" + i + "\t" + summary(eval));
  }
}
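The `summary(eval)` helper called above is not shown in this section. Below is a minimal sketch of what it might look like, built only on Weka's Evaluation API (correct() and pctCorrect()); the exact formatting is an assumption.

// Hypothetical helper matching the header printed above: number of correctly
// classified instances and the percentage of correct classifications.
private static String summary(Evaluation eval) {
  return String.format("%.0f \t %.2f%%", eval.correct(), eval.pctCorrect());
}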
/**
 * Test to see if the classifier will work without error
 * for a peculiar ARFF file.
 * @throws Exception
 */
@Test
public void testSpecialCharacters() throws Exception {
  PyScriptClassifier ps = (PyScriptClassifier) getClassifier();
  ps.setPythonFile(new File("scripts/zeror.py"));
  DataSource ds = new DataSource("datasets/special-chars.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);
  ps.buildClassifier(data);
}
/**
 * Not testing anything in particular here - just make sure
 * that we can train the RF example without some
 * exception being thrown.
 * @throws Exception
 */
@Test
public void testRandomForestOnDiabetes() throws Exception {
  System.out.println("testRandomForestOnDiabetes()");
  PyScriptClassifier ps = (PyScriptClassifier) getClassifier();
  ps.setPythonFile(new File("scripts/scikit-rf.py"));
  ps.setArguments("num_trees=10");
  DataSource ds = new DataSource("datasets/diabetes.arff");
  Instances train = ds.getDataSet();
  train.setClassIndex(train.numAttributes() - 1);
  ps.buildClassifier(train);
}
/**
 * Not testing anything in particular here - just make sure
 * that we can train the linear reg example without some
 * exception being thrown.
 * @throws Exception
 */
@Test
public void testLinearRegressionOnDiabetes() throws Exception {
  System.out.println("testLinearRegressionOnDiabetes()");
  PyScriptClassifier ps = (PyScriptClassifier) getClassifier();
  ps.setPythonFile(new File("scripts/linear-reg.py"));
  ps.setArguments("alpha=0.01;epsilon=0.0001");
  DataSource ds = new DataSource("datasets/diabetes_numeric.arff");
  Instances train = ds.getDataSet();
  train.setClassIndex(train.numAttributes() - 1);
  ps.setShouldStandardize(true);
  ps.buildClassifier(train);
}
/**
 * Not testing anything in particular here - just make sure
 * that we can train ZeroR on Iris without some exception
 * being thrown.
 * @throws Exception
 */
@Test
public void testZeroROnIris() throws Exception {
  System.out.println("testZeroROnIris()");
  PyScriptClassifier ps = (PyScriptClassifier) getClassifier();
  ps.setPythonFile(new File("scripts/zeror.py"));
  ps.setArguments("");
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances train = ds.getDataSet();
  train.setClassIndex(train.numAttributes() - 1);
  ps.buildClassifier(train);
}
/**
 * Test to see that an exception gets thrown when a "bad"
 * script is given to the classifier.
 * @throws Exception
 */
@Test(expected = Exception.class)
public void testExceptionRaiser() throws Exception {
  System.out.println("testExceptionRaiser()");
  PyScriptClassifier ps = (PyScriptClassifier) getClassifier();
  ps.setDebug(true);
  ps.setPythonFile(new File("scripts/test-exception.py"));
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances train = ds.getDataSet();
  ps.buildClassifier(train);
  assertEquals(ps.getModelString(), null);
}
@Override
protected void setUp() throws Exception {
  super.setUp();
  DataSource ds = new DataSource("tests/filter-test.arff");
  m_Instances = ds.getDataSet();
  //m_Instances.setClassIndex( m_Instances.numAttributes() - 1);
}
/**
 * Test to see if a regression dataset works. For this we're just
 * making sure that an exception isn't thrown.
 */
@Test
public void testRegressionDataset() throws Exception {
  DataSource ds = new DataSource("datasets/diabetes_numeric.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);

  PyScriptFilter filter = new PyScriptFilter();
  filter.setPythonFile(new File("scripts/standardise.py"));
  filter.setInputFormat(data);
  Filter.useFilter(data, filter);
}
/**
 * Test to see if the script save feature works.
 * @throws Exception
 */
@Test
public void testSaveScript() throws Exception {
  DataSource ds = new DataSource("datasets/iris.arff");
  Instances data = ds.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);

  PyScriptFilter filter = new PyScriptFilter();
  filter.setPythonFile(new File("scripts/standardise.py"));
  filter.setSaveScript(true);
  filter.determineOutputFormat(data);

  // ok, now change the script
  filter.setPythonFile(new File("not-a-real-file.py"));
  filter.process(data);
}
/**
 * Loads a saved model and tests a file.
 *
 * @param modelPath path to the serialized LibD3C model
 * @param testPath path to the test ARFF file
 * @param resultFilePath path of the file to which predictions are written
 * @throws Exception if loading, classification, or writing fails
 */
public void LoadModel(String modelPath, String testPath, String resultFilePath) throws Exception {
  try {
    LibD3C c1 = new LibD3C();
    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelPath));
    c1 = (LibD3C) ois.readObject();
    ois.close();

    DataSource source = new DataSource(testPath);
    Instances data = source.getDataSet();
    data.setClassIndex(data.numAttributes() - 1);

    BufferedWriter writer = new BufferedWriter(new FileWriter(resultFilePath));
    BufferedWriter writePro = new BufferedWriter(new FileWriter(".probility"));
    writer.write("prediction " + "origin class");
    writer.newLine();
    for (int j = 0; j < data.numInstances(); j++) {
      writePro.write(String.valueOf(c1.distributionForInstance(data.get(j))[1]));
      writePro.newLine();
      writer.write(String.valueOf(c1.classifyInstance(data.get(j))) + ",");
      writer.write(String.valueOf(data.get(j).classValue()));
      writer.newLine();
    }
    writer.flush();
    writer.close();
    writePro.flush();
    writePro.close();
  } catch (Exception e) {
    e.printStackTrace();
  }
}
/**
 * Executes the classifier.
 *
 * @param prepfeatures the prepared features in arff format
 * @param modelfile the path to the serialized model
 * @param clusters the clusters to classify
 * @return a map of the classified clusters, the keys are the classes
 *         and the values are lists of cluster id's belonging to those classes
 */
private Map<ClusterClass, List<StoredDomainCluster>> executeClassifier(String prepfeatures, String modelfile,
    List<StoredDomainCluster> clusters) {
  Map<ClusterClass, List<StoredDomainCluster>> retval = new HashMap<ClusterClass, List<StoredDomainCluster>>();
  try {
    DataSource source = new DataSource(new ByteArrayInputStream(prepfeatures.getBytes()));
    Instances data = source.getDataSet();
    if (data.classIndex() == -1) {
      data.setClassIndex(data.numAttributes() - 1);
    }
    String[] options = weka.core.Utils.splitOptions("-p 0");
    J48 cls = (J48) weka.core.SerializationHelper.read(modelfile);
    cls.setOptions(options);
    for (int i = 0; i < data.numInstances(); i++) {
      double pred = cls.classifyInstance(data.instance(i));
      ClusterClass clusClass = ClusterClass.valueOf(data.classAttribute().value((int) pred).toUpperCase());
      if (!retval.containsKey(clusClass)) {
        retval.put(clusClass, new ArrayList<StoredDomainCluster>());
      }
      retval.get(clusClass).add(clusters.get(i));
    }
  } catch (Exception e) {
    if (log.isErrorEnabled()) {
      log.error("Error executing classifier.", e);
    }
  }
  return retval;
}
public TrainedModelPredictionMaker(String attributeSelectionObjPath, String modelObjPath, String instancesPath,
    String classIndex, String predictionPath) {
  // Go forth and load some instances
  try {
    DataSource dataSource = new DataSource(new FileInputStream(instancesPath));
    Instances instances = dataSource.getDataSet();

    // Make sure the class index is set
    if (instances.classIndex() == -1) {
      if (classIndex.equals("last"))
        instances.setClassIndex(instances.numAttributes() - 1);
      else
        instances.setClassIndex(Integer.parseInt(classIndex));
    }

    // Load up the attribute selection if we need to
    if (attributeSelectionObjPath != null) {
      AttributeSelection as = (AttributeSelection) weka.core.SerializationHelper.read(attributeSelectionObjPath);
      instances = as.reduceDimensionality(instances);
    }

    // Load up yonder classifier
    AbstractClassifier classifier = (AbstractClassifier) weka.core.SerializationHelper.read(modelObjPath);

    // Make the evaluation
    eval = new Evaluation(instances);
    ClassifierRunner.EvaluatorThread thrd = new ClassifierRunner.EvaluatorThread(eval, classifier, instances, predictionPath);
    thrd.run();
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}