opencv 机器学习算法汇总

转载自http://blog.csdn.net/cnbloger/article/details/78005680

opencv提供了非常多的机器学习算法用于研究。这里对这些算法进行分类学习和研究,以抛砖引玉。这里使用的机器学习算法包括:人工神经网络,boost,决策树,最近邻,逻辑回归,贝叶斯,随机森林,SVM等算法等。

机器学习的过程相同,都要经历1、收集样本数据sampleData2.训练分类器mode3.对测试数据testData进行预测

这里使用一个在别处看到的例子,利用身高体重等原始信息预测男女的概率。通过一些简单的数据学习,用测试数据预测男女概率。

[java] 
view plain
 copy

  1. import org.opencv.core.Core;  
  2. import org.opencv.core.CvType;  
  3. import org.opencv.core.Mat;  
  4. import org.opencv.core.TermCriteria;  
  5. import org.opencv.ml.ANN_MLP;  
  6. import org.opencv.ml.Boost;  
  7. import org.opencv.ml.DTrees;  
  8. import org.opencv.ml.KNearest;  
  9. import org.opencv.ml.LogisticRegression;  
  10. import org.opencv.ml.Ml;  
  11. import org.opencv.ml.NormalBayesClassifier;  
  12. import org.opencv.ml.RTrees;  
  13. import org.opencv.ml.SVM;  
  14. import org.opencv.ml.SVMSGD;  
  15. import org.opencv.ml.TrainData;  
  16.   
  17. public class ML {  
  18.     public static void main(String[] args) {  
  19.         System.loadLibrary(Core.NATIVE_LIBRARY_NAME);  
  20.         // 训练数据,两个维度,表示身高和体重  
  21.         float[] trainingData = { 18680185811605016148 };  
  22.         // 训练标签数据,前两个表示男生0,后两个表示女生1,由于使用了多种机器学习算法,他们的输入有些不一样,所以labelsMat有三种   
  23.         float[] labels = { 0f, 0f, 0f, 0f, 1f, 1f, 1f, 1f };  
  24.         int[] labels2 = { 0011 };  
  25.         float[] labels3 = { 0011 };  
  26.         // 测试数据,先男后女  
  27.         float[] test = { 1847915950 };  
  28.   
  29.         Mat trainingDataMat = new Mat(42, CvType.CV_32FC1);  
  30.         trainingDataMat.put(00, trainingData);  
  31.   
  32.         Mat labelsMat = new Mat(42, CvType.CV_32FC1);  
  33.         labelsMat.put(00, labels);  
  34.   
  35.         Mat labelsMat2 = new Mat(41, CvType.CV_32SC1);  
  36.         labelsMat2.put(00, labels2);  
  37.   
  38.         Mat labelsMat3 = new Mat(41, CvType.CV_32FC1);  
  39.         labelsMat3.put(00, labels3);  
  40.   
  41.         Mat sampleMat = new Mat(22, CvType.CV_32FC1);  
  42.         sampleMat.put(00, test);  
  43.   
  44.         MyAnn(trainingDataMat, labelsMat, sampleMat);  
  45.         MyBoost(trainingDataMat, labelsMat2, sampleMat);  
  46.         MyDtrees(trainingDataMat, labelsMat2, sampleMat);  
  47.         MyKnn(trainingDataMat, labelsMat3, sampleMat);  
  48.         MyLogisticRegression(trainingDataMat, labelsMat3, sampleMat);  
  49.         MyNormalBayes(trainingDataMat, labelsMat2, sampleMat);  
  50.         MyRTrees(trainingDataMat, labelsMat2, sampleMat);  
  51.         MySvm(trainingDataMat, labelsMat2, sampleMat);  
  52.         MySvmsgd(trainingDataMat, labelsMat2, sampleMat);  
  53.     }  
  54.   
  55.     // 人工神经网络  
  56.     public static Mat MyAnn(Mat trainingData, Mat labels, Mat testData) {  
  57.         // train data using aNN  
  58.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  59.         Mat layerSizes = new Mat(14, CvType.CV_32FC1);  
  60.         // 含有两个隐含层的网络结构,输入、输出层各两个节点,每个隐含层含两个节点  
  61.         layerSizes.put(00new float[] { 2222 });  
  62.         ANN_MLP ann = ANN_MLP.create();  
  63.         ann.setLayerSizes(layerSizes);  
  64.         ann.setTrainMethod(ANN_MLP.BACKPROP);  
  65.         ann.setBackpropWeightScale(0.1);  
  66.         ann.setBackpropMomentumScale(0.1);  
  67.         ann.setActivationFunction(ANN_MLP.SIGMOID_SYM, 11);  
  68.         ann.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER + TermCriteria.EPS, 3000.0));  
  69.         boolean success = ann.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());  
  70.         System.out.println(“Ann training result: “ + success);  
  71.         // ann.save(“D:/bp.xml”);//存储模型  
  72.         // ann.load(“D:/bp.xml”);//读取模型  
  73.   
  74.         // 测试数据  
  75.         Mat responseMat = new Mat();  
  76.         ann.predict(testData, responseMat, 0);  
  77.         System.out.println(“Ann responseMat:\n” + responseMat.dump());  
  78.         for (int i = 0; i < responseMat.size().height; i++) {  
  79.             if (responseMat.get(i, 0)[0] + responseMat.get(i, i)[0] >= 1)  
  80.                 System.out.println(“Girl\n”);  
  81.             if (responseMat.get(i, 0)[0] + responseMat.get(i, i)[0] < 1)  
  82.                 System.out.println(“Boy\n”);  
  83.         }  
  84.         return responseMat;  
  85.     }  
  86.   
  87.     // Boost  
  88.     public static Mat MyBoost(Mat trainingData, Mat labels, Mat testData) {  
  89.         Boost boost = Boost.create();  
  90.         // boost.setBoostType(Boost.DISCRETE);  
  91.         boost.setBoostType(Boost.GENTLE);  
  92.         boost.setWeakCount(2);  
  93.         boost.setWeightTrimRate(0.95);  
  94.         boost.setMaxDepth(2);  
  95.         boost.setUseSurrogates(false);  
  96.         boost.setPriors(new Mat());  
  97.   
  98.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  99.         boolean success = boost.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());  
  100.         System.out.println(“Boost training result: “ + success);  
  101.         // boost.save(“D:/bp.xml”);//存储模型  
  102.   
  103.         Mat responseMat = new Mat();  
  104.         float response = boost.predict(testData, responseMat, 0);  
  105.         System.out.println(“Boost responseMat:\n” + responseMat.dump());  
  106.         for (int i = 0; i < responseMat.height(); i++) {  
  107.             if (responseMat.get(i, 0)[0] == 0)  
  108.                 System.out.println(“Boy\n”);  
  109.             if (responseMat.get(i, 0)[0] == 1)  
  110.                 System.out.println(“Girl\n”);  
  111.         }  
  112.         return responseMat;  
  113.     }  
  114.   
  115.     // 决策树  
  116.     public static Mat MyDtrees(Mat trainingData, Mat labels, Mat testData) {  
  117.         DTrees dtree = DTrees.create(); // 创建分类器  
  118.         dtree.setMaxDepth(8); // 设置最大深度  
  119.         dtree.setMinSampleCount(2);  
  120.         dtree.setUseSurrogates(false);  
  121.         dtree.setCVFolds(0); // 交叉验证  
  122.         dtree.setUse1SERule(false);  
  123.         dtree.setTruncatePrunedTree(false);  
  124.   
  125.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  126.         boolean success = dtree.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());  
  127.         System.out.println(“Dtrees training result: “ + success);  
  128.         // dtree.save(“D:/bp.xml”);//存储模型  
  129.   
  130.         Mat responseMat = new Mat();  
  131.         float response = dtree.predict(testData, responseMat, 0);  
  132.         System.out.println(“Dtrees responseMat:\n” + responseMat.dump());  
  133.         for (int i = 0; i < responseMat.height(); i++) {  
  134.             if (responseMat.get(i, 0)[0] == 0)  
  135.                 System.out.println(“Boy\n”);  
  136.             if (responseMat.get(i, 0)[0] == 1)  
  137.                 System.out.println(“Girl\n”);  
  138.         }  
  139.         return responseMat;  
  140.     }  
  141.   
  142.     // K最邻近  
  143.     public static Mat MyKnn(Mat trainingData, Mat labels, Mat testData) {  
  144.         final int K = 2;  
  145.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  146.         KNearest knn = KNearest.create();  
  147.         boolean success = knn.train(trainingData, Ml.ROW_SAMPLE, labels);  
  148.         System.out.println(“Knn training result: “ + success);  
  149.         // knn.save(“D:/bp.xml”);//存储模型  
  150.   
  151.         // find the nearest neighbours of test data  
  152.         Mat results = new Mat();  
  153.         Mat neighborResponses = new Mat();  
  154.         Mat dists = new Mat();  
  155.         knn.findNearest(testData, K, results, neighborResponses, dists);  
  156.         System.out.println(“results:\n” + results.dump());  
  157.         System.out.println(“Knn neighborResponses:\n” + neighborResponses.dump());  
  158.         System.out.println(“dists:\n” + dists.dump());  
  159.         for (int i = 0; i < results.height(); i++) {  
  160.             if (results.get(i, 0)[0] == 0)  
  161.                 System.out.println(“Boy\n”);  
  162.             if (results.get(i, 0)[0] == 1)  
  163.                 System.out.println(“Girl\n”);  
  164.         }  
  165.   
  166.         return results;  
  167.     }  
  168.   
  169.     // 逻辑回归  
  170.     public static Mat MyLogisticRegression(Mat trainingData, Mat labels, Mat testData) {  
  171.         LogisticRegression lr = LogisticRegression.create();  
  172.   
  173.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  174.         boolean success = lr.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());  
  175.         System.out.println(“LogisticRegression training result: “ + success);  
  176.         // lr.save(“D:/bp.xml”);//存储模型  
  177.   
  178.         Mat responseMat = new Mat();  
  179.         float response = lr.predict(testData, responseMat, 0);  
  180.         System.out.println(“LogisticRegression responseMat:\n” + responseMat.dump());  
  181.         for (int i = 0; i < responseMat.height(); i++) {  
  182.             if (responseMat.get(i, 0)[0] == 0)  
  183.                 System.out.println(“Boy\n”);  
  184.             if (responseMat.get(i, 0)[0] == 1)  
  185.                 System.out.println(“Girl\n”);  
  186.         }  
  187.         return responseMat;  
  188.     }  
  189.   
  190.     // 贝叶斯  
  191.     public static Mat MyNormalBayes(Mat trainingData, Mat labels, Mat testData) {  
  192.         NormalBayesClassifier nb = NormalBayesClassifier.create();  
  193.   
  194.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  195.         boolean success = nb.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());  
  196.         System.out.println(“NormalBayes training result: “ + success);  
  197.         // nb.save(“D:/bp.xml”);//存储模型  
  198.   
  199.         Mat responseMat = new Mat();  
  200.         float response = nb.predict(testData, responseMat, 0);  
  201.         System.out.println(“NormalBayes responseMat:\n” + responseMat.dump());  
  202.         for (int i = 0; i < responseMat.height(); i++) {  
  203.             if (responseMat.get(i, 0)[0] == 0)  
  204.                 System.out.println(“Boy\n”);  
  205.             if (responseMat.get(i, 0)[0] == 1)  
  206.                 System.out.println(“Girl\n”);  
  207.         }  
  208.         return responseMat;  
  209.     }  
  210.   
  211.     // 随机森林  
  212.     public static Mat MyRTrees(Mat trainingData, Mat labels, Mat testData) {  
  213.         RTrees rtrees = RTrees.create();  
  214.         rtrees.setMaxDepth(4);  
  215.         rtrees.setMinSampleCount(2);  
  216.         rtrees.setRegressionAccuracy(0.f);  
  217.         rtrees.setUseSurrogates(false);  
  218.         rtrees.setMaxCategories(16);  
  219.         rtrees.setPriors(new Mat());  
  220.         rtrees.setCalculateVarImportance(false);  
  221.         rtrees.setActiveVarCount(1);  
  222.         rtrees.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, 50));  
  223.         TrainData tData = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  224.         boolean success = rtrees.train(tData.getSamples(), Ml.ROW_SAMPLE, tData.getResponses());  
  225.         System.out.println(“Rtrees training result: “ + success);  
  226.         // rtrees.save(“D:/bp.xml”);//存储模型  
  227.   
  228.         Mat responseMat = new Mat();  
  229.         rtrees.predict(testData, responseMat, 0);  
  230.         System.out.println(“Rtrees responseMat:\n” + responseMat.dump());  
  231.         for (int i = 0; i < responseMat.height(); i++) {  
  232.             if (responseMat.get(i, 0)[0] == 0)  
  233.                 System.out.println(“Boy\n”);  
  234.             if (responseMat.get(i, 0)[0] == 1)  
  235.                 System.out.println(“Girl\n”);  
  236.         }  
  237.         return responseMat;  
  238.     }  
  239.   
  240.     // 支持向量机  
  241.     public static Mat MySvm(Mat trainingData, Mat labels, Mat testData) {  
  242.         SVM svm = SVM.create();  
  243.         svm.setKernel(SVM.LINEAR);  
  244.         svm.setType(SVM.C_SVC);  
  245.         TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 10000);  
  246.         svm.setTermCriteria(criteria);  
  247.         svm.setGamma(0.5);  
  248.         svm.setNu(0.5);  
  249.         svm.setC(1);  
  250.   
  251.         TrainData td = TrainData.create(trainingData, Ml.ROW_SAMPLE, labels);  
  252.         boolean success = svm.train(td.getSamples(), Ml.ROW_SAMPLE, td.getResponses());  
  253.         System.out.println(“Svm training result: “ + success);  
  254.         // svm.save(“D:/bp.xml”);//存储模型  
  255.         // svm.load(“D:/bp.xml”);//读取模型  
  256.   
  257.         Mat responseMat = new Mat();  
  258.         svm.predict(testData, responseMat, 0);  
  259.         System.out.println(“SVM responseMat:\n” + responseMat.dump());  
  260.         for (int i = 0; i < responseMat.height(); i++) {  
  261.             if (responseMat.get(i, 0)[0] == 0)  
  262.                 System.out.println(“Boy\n”);  
  263.             if (responseMat.get(i, 0)[0] == 1)  
  264.                 System.out.println(“Girl\n”);  
  265.         }  
  266.         return responseMat;  
  267.     }  
  268.   
  269.     // SGD支持向量机  
  270.     public static Mat MySvmsgd(Mat trainingData, Mat labels, Mat testData) {  
  271.         SVMSGD Svmsgd = SVMSGD.create();  
  272.         TermCriteria criteria = new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 10000);  
  273.         Svmsgd.setTermCriteria(criteria);  
  274.         Svmsgd.setInitialStepSize(2);  
  275.         Svmsgd.setSvmsgdType(SVMSGD.SGD);  
  276.         Svmsgd.setMarginRegularization(0.5f);  
  277.         boolean success = Svmsgd.train(trainingData, Ml.ROW_SAMPLE, labels);  
  278.         System.out.println(“SVMSGD training result: “ + success);  
  279.         // svm.save(“D:/bp.xml”);//存储模型  
  280.         // svm.load(“D:/bp.xml”);//读取模型  
  281.   
  282.         Mat responseMat = new Mat();  
  283.         Svmsgd.predict(testData, responseMat, 0);  
  284.         System.out.println(“SVMSGD responseMat:\n” + responseMat.dump());  
  285.         for (int i = 0; i < responseMat.height(); i++) {  
  286.             if (responseMat.get(i, 0)[0] == 0)  
  287.                 System.out.println(“Boy\n”);  
  288.             if (responseMat.get(i, 0)[0] == 1)  
  289.                 System.out.println(“Girl\n”);  
  290.         }  
  291.         return responseMat;  
  292.     }  
  293. }  

输出结果:

[plain] 
view plain
 copy

  1. Ann training result: true  
  2. Ann responseMat:  
  3. [0.014712702, 0.01492399;  
  4.  0.98786205, 0.987822]  
  5. Boy  
  6.   
  7. Girl  
  8.   
  9. Boost training result: true  
  10. Boost responseMat:  
  11. [0;  
  12.  0]  
  13. Boy  
  14.   
  15. Boy  
  16.   
  17. Dtrees training result: true  
  18. Dtrees responseMat:  
  19. [0;  
  20.  1]  
  21. Boy  
  22.   
  23. Girl  
  24.   
  25. Knn training result: true  
  26. results:  
  27. [0;  
  28.  1]  
  29. Knn neighborResponses:  
  30. [0, 0;  
  31.  1, 1]  
  32. dists:  
  33. [5, 5;  
  34.  1, 8]  
  35. Boy  
  36.   
  37. Girl  
  38.   
  39. LogisticRegression training result: true  
  40. LogisticRegression responseMat:  
  41. [0;  
  42.  1]  
  43. Boy  
  44.   
  45. Girl  
  46.   
  47. NormalBayes training result: true  
  48. NormalBayes responseMat:  
  49. [0;  
  50.  1]  
  51. Boy  
  52.   
  53. Girl  
  54.   
  55. Rtrees training result: true  
  56. Rtrees responseMat:  
  57. [0;  
  58.  1]  
  59. Boy  
  60.   
  61. Girl  
  62.   
  63. Svm training result: true  
  64. SVM responseMat:  
  65. [0;  
  66.  1]  
  67. Boy  
  68.   
  69. Girl  
  70.   
  71. SVMSGD training result: true  
  72. SVMSGD responseMat:  
  73. [1;  
  74.  1]  
  75. Girl  
  76.   
  77. Girl  
点赞