读代码-TrainClassifier跟TestClassifier
读代码-TrainClassifier和TestClassifier
package org.apache.mahout.classifier.bayes;
public final class TrainClassifier
bayes和cbayes的入口类
两个分支
先设定所有默认参数,如果有非默认项再覆盖
由于参数过多,定义一个类封装参数,便于后续传递
package org.apache.mahout.classifier.bayes;
public final class TestClassifier
入口
分并行和非并行两种实现
package org.apache.mahout.classifier.bayes;
public final class TrainClassifier
bayes和cbayes的入口类
两个分支
/**
 * Trains the Bayes classifier by handing the arguments to a fresh {@code BayesDriver}.
 *
 * @param dir       path forwarded to the driver as its first argument
 * @param outputDir path forwarded to the driver as its second argument
 * @param params    parameter bundle forwarded unchanged to the driver
 * @throws IOException if the driver's {@code runJob} throws it
 */
public static void trainNaiveBayes(Path dir, Path outputDir, BayesParameters params)
    throws IOException {
  new BayesDriver().runJob(dir, outputDir, params);
}

/**
 * Trains the complementary Bayes classifier by handing the arguments to a fresh
 * {@code CBayesDriver}.
 *
 * @param dir       path forwarded to the driver as its first argument
 * @param outputDir path forwarded to the driver as its second argument
 * @param params    parameter bundle forwarded unchanged to the driver
 * @throws IOException if the driver's {@code runJob} throws it
 */
public static void trainCNaiveBayes(Path dir, Path outputDir, BayesParameters params)
    throws IOException {
  new CBayesDriver().runJob(dir, outputDir, params);
}
先设定所有默认参数,如果有非默认项再覆盖
由于参数过多,定义一个类封装参数,便于后续传递
BayesParameters params = new BayesParameters(); // Setting all the default parameter values params.setGramSize(1); params.setMinDF(1); params.set("alpha_i","1.0"); params.set("dataSource", "hdfs"); if (cmdLine.hasOption(gramSizeOpt)) { params.setGramSize(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt))); } if (cmdLine.hasOption(minDfOpt)) { params.setMinDF(Integer.parseInt((String) cmdLine.getValue(minDfOpt))); }
// Resolve the input and output locations from the command line.
Path inputPath = new Path((String) cmdLine.getValue(inputDirOpt));
Path outputPath = new Path((String) cmdLine.getValue(outputOpt));

// Anything other than "cbayes" (case-insensitive) trains the plain Bayes classifier.
if (!"cbayes".equalsIgnoreCase(classifierType)) {
  log.info("Training Bayes Classifier");
  trainNaiveBayes(inputPath, outputPath, params);
} else {
  log.info("Training Complementary Bayes Classifier");
  trainCNaiveBayes(inputPath, outputPath, params);
}
package org.apache.mahout.classifier.bayes;
public final class TestClassifier
入口
/**
 * Runs classification through {@code BayesClassifierDriver.runJob}.
 *
 * @param params parameter bundle forwarded unchanged to the driver
 * @throws IOException if the driver's {@code runJob} throws it
 */
public static void classifyParallel(BayesParameters params) throws IOException {
  BayesClassifierDriver.runJob(params);
}
分并行和非并行两种实现
// Pick the implementation by name (case-insensitive); any other value runs nothing.
if ("mapreduce".equalsIgnoreCase(classificationMethod)) {
  classifyParallel(params);
} else if ("sequential".equalsIgnoreCase(classificationMethod)) {
  classifySequential(params);
}