/**
 * Loads the embedding from a given ARFF file. First converts the ARFF to a temporary CSV file
 * and then continues the loading mechanism with that CSV file.
 *
 * @param path path to the ARFF file
 */
private void loadEmbeddingFromArff(String path) {
    // Try loading the ARFF file and converting it to a temporary, space-separated CSV file
    try (FileReader reader = new FileReader(path)) {
        Instances insts = new Instances(reader);

        CSVSaver saver = new CSVSaver();
        saver.setFieldSeparator(" ");
        saver.setInstances(insts);

        // Create the temporary file directly inside the temp directory; the previous
        // Paths.get(tmpdir, uuid, ".csv") pointed into a non-existent subdirectory.
        final File tmpFile = Paths
                .get(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString() + ".csv")
                .toFile();
        saver.setFile(tmpFile);
        saver.setNoHeaderRow(true);
        saver.writeBatch();

        loadEmbeddingFromCSV(tmpFile);
        tmpFile.delete();
    } catch (Exception e) {
        throw new RuntimeException(
                "ARFF file could not be read (" + wordVectorLocation.getAbsolutePath() + ")", e);
    }
}
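// Note: loadEmbeddingFromCSV(...) is defined elsewhere in the class and is not part of the
// snippet above. A minimal, hypothetical sketch of how such a helper could read the temporary
// space-separated, headerless CSV back into Weka Instances (the method name and the word-vector
// handling are assumptions; assumes weka.core.converters.CSVLoader is imported):
private void loadEmbeddingFromCSV(File csvFile) throws IOException {
    // Mirror the CSVSaver settings used when the temporary file was written
    CSVLoader loader = new CSVLoader();
    loader.setFieldSeparator(" ");           // matches saver.setFieldSeparator(" ")
    loader.setNoHeaderRowPresent(true);      // matches saver.setNoHeaderRow(true)
    loader.setSource(csvFile);
    Instances embedding = loader.getDataSet();
    // ... build the word-vector lookup from 'embedding' here ...
}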
public void writePredictions(Instances ins, String filePrefix) {
    try {
        // Write the predictions as an ARFF file
        try (BufferedWriter writer =
                new BufferedWriter(new FileWriter(outputDir + "/" + filePrefix + ".arff"))) {
            writer.write(ins.toString());
            writer.newLine();
        }

        // Additionally write the predictions as a tab-separated file
        CSVSaver s = new CSVSaver();
        s.setFile(new File(outputDir + "/" + filePrefix + ".tsv"));
        s.setInstances(ins);
        s.setFieldSeparator("\t");
        s.writeBatch();
    } catch (IOException e) {
        e.printStackTrace();
    }
}
public int writePredictions(Instances ins, String filePrefix) {
    try {
        System.out.println("Trying to create the following files:");
        System.out.println(outputDir + "/" + filePrefix + ".arff");
        System.out.println(outputDir + "/" + filePrefix + ".tsv");

        // Write the predictions as an ARFF file
        try (BufferedWriter writer =
                new BufferedWriter(new FileWriter(outputDir + "/" + filePrefix + ".arff"))) {
            writer.write(ins.toString());
            writer.newLine();
        }

        // Additionally write the predictions as a tab-separated file
        CSVSaver s = new CSVSaver();
        s.setFile(new File(outputDir + "/" + filePrefix + ".tsv"));
        s.setInstances(ins);
        s.setFieldSeparator("\t");
        s.writeBatch();
    } catch (IOException e) {
        e.printStackTrace();
        return 1;
    }
    return 0;
}
/**
 * Saves the given instances to a CSV file at the given path.
 *
 * @param path path of the CSV file to write
 * @param data instances to save
 */
public static void saveDataToCsvFile(String path, Instances data) throws IOException {
    System.out.println("\nSaving to file " + path + "...");
    CSVSaver saver = new CSVSaver();
    saver.setInstances(data);
    saver.setFile(new File(path));
    saver.writeBatch();
}
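// For illustration only: a hypothetical caller that loads an ARFF data set and writes it out
// via saveDataToCsvFile above. The "iris" file names are placeholders; assumes
// weka.core.converters.ConverterUtils.DataSource is imported.
public static void main(String[] args) throws Exception {
    Instances data = new ConverterUtils.DataSource("iris.arff").getDataSet();
    saveDataToCsvFile("iris.csv", data);
}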
public static LinearRegressionSummary createCommonPrediction(final String productID)
        throws IOException, GitAPIException {
    logger.info("productID = {}", productID);

    final Set<RetailAnalytics> set = getAllRetailAnalytics(RETAIL_ANALYTICS_ + productID)
            .filter(ra -> productID.isEmpty() || ra.getProductId().equals(productID))
            //.filter(ra -> ra.getShopSize() == 100 || ra.getShopSize() == 500 || ra.getShopSize() == 1_000 || ra.getShopSize() == 10_000 || ra.getShopSize() == 100_000)
            // .filter(ra -> ra.getShopSize() > 0)
            // .filter(ra -> ra.getSellVolumeNumber() > 0)
            // .filter(ra -> ra.getDemography() > 0)
            // .filter(ra -> ra.getMarketIdx().isEmpty() || ra.getMarketIdx().equals("E"))
            .collect(toSet());
    logger.info("set.size() = {}", set.size());

    if (!set.isEmpty()) {
        // Group the analytics by product and save them
        // final Map<String, List<RetailAnalytics>> retailAnalyticsHist = set.parallelStream()
        //         .filter(ra -> ra.getNotoriety() >= 100)
        //         .collect(Collectors.groupingBy(RetailAnalytics::getProductId));
        // final ExclusionStrategy es = new HistAnalytExclStrat();
        // for (final Map.Entry<String, List<RetailAnalytics>> entry : retailAnalyticsHist.entrySet()) {
        //     final String fileNamePath = GitHubPublisher.localPath + RetailSalePrediction.predict_retail_sales + File.separator
        //             + RetailSalePrediction.RETAIL_ANALYTICS_HIST + File.separator + entry.getKey() + ".json";
        //     Utils.writeToGson(fileNamePath, squeeze(entry.getValue()), es);
        // }

        final Set<String> productIds =
                set.parallelStream().map(RetailAnalytics::getProductId).collect(Collectors.toSet());
        final Set<String> productCategories =
                set.parallelStream().map(RetailAnalytics::getProductCategory).collect(Collectors.toSet());

        try {
            logger.info("createTrainingSet");
            final Instances trainingSet = createTrainingSet(set, productIds, productCategories);
            // final Standardize standardize = new Standardize();
            // standardize.setInputFormat(trainingSetRaw);
            // final Instances trainingSet = Filter.useFilter(trainingSetRaw, standardize);

            logger.info("ArffSaver");
            final ArffSaver saver = new ArffSaver();
            saver.setInstances(trainingSet);
            saver.setFile(new File(Utils.getDir() + WEKA + File.separator + "common_" + productID + ".arff"));
            saver.writeBatch();

            logger.info("CSVSaver");
            final CSVSaver saverCsv = new CSVSaver();
            saverCsv.setInstances(trainingSet);
            saverCsv.setFile(new File(Utils.getDir() + WEKA + File.separator + "common_" + productID + ".csv"));
            saverCsv.writeBatch();

            // final File file = new File(GitHubPublisher.localPath + RetailSalePrediction.predict_retail_sales + File.separator + WEKA + File.separator + "common.arff");
            // file.delete();

            final LinearRegressionSummary summary = trainLinearRegression(trainingSet, productID);
            // trainRandomCommittee(trainingSet);
            // trainDecisionTable(trainingSet);
            // trainMultilayerPerceptron(trainingSet);
            // trainRandomForest(trainingSet);
            // trainRandomTree(trainingSet);
            // trainLibSvm(trainingSet);
            // logger.info("begin trainJ48BySet");
            // trainJ48BySet(trainingSet);
            // logger.info("end trainJ48BySet");
            //
            // logger.info("begin trainJ48CrossValidation");
            // trainJ48CrossValidation(trainingSet);
            // logger.info("end trainJ48CrossValidation");

            // Record the date the data was last updated
            final DateFormat df = new SimpleDateFormat("dd.MM.yyyy");
            Utils.writeToGson(GitHubPublisher.localPath + RetailSalePrediction.predict_retail_sales
                    + File.separator + "updateDate.json", new UpdateDate(df.format(new Date())));

            return summary;
        } catch (final Exception e) {
            logger.info("productID = {}", productID);
            logger.error(e.getLocalizedMessage(), e);
        }
    }
    return null;
}
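// trainLinearRegression(...) is not shown above. A minimal, hypothetical sketch of what that
// step could look like with Weka's LinearRegression classifier. The class-index choice, the
// training-set evaluation, and the LinearRegressionSummary constructor are assumptions;
// assumes weka.classifiers.functions.LinearRegression and weka.classifiers.Evaluation are imported.
private static LinearRegressionSummary trainLinearRegression(final Instances trainingSet,
        final String productID) throws Exception {
    // Predict the last attribute; the real code may use a different class attribute
    trainingSet.setClassIndex(trainingSet.numAttributes() - 1);

    final LinearRegression lr = new LinearRegression();
    lr.buildClassifier(trainingSet);

    // Evaluate on the training data just to obtain summary statistics
    final Evaluation eval = new Evaluation(trainingSet);
    eval.evaluateModel(lr, trainingSet);
    logger.info("correlation = {}, MAE = {}",
            eval.correlationCoefficient(), eval.meanAbsoluteError());

    return new LinearRegressionSummary(productID, lr, eval); // hypothetical constructor
}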