
rf_split.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
36 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
37 #include <algorithm>
38 #include <cstddef>
39 #include <map>
40 #include <numeric>
41 #include <math.h>
42 #include "../mathutil.hxx"
43 #include "../array_vector.hxx"
44 #include "../sized_int.hxx"
45 #include "../matrix.hxx"
46 #include "../random.hxx"
47 #include "../functorexpression.hxx"
48 #include "rf_nodeproxy.hxx"
49 //#include "rf_sampling.hxx"
50 #include "rf_region.hxx"
51 //#include "../hokashyap.hxx"
52 //#include "vigra/rf_helpers.hxx"
53 
54 namespace vigra
55 {
56 
57 // Incomplete Class to ensure that findBestSplit is always implemented in
58 // the derived classes of SplitBase
59 class CompileTimeError;
60 
61 
62 namespace detail
63 {
64  template<class Tag>
65  class Normalise
66  {
67  public:
68  template<class Iter>
69  static void exec(Iter /*begin*/, Iter /*end*/)
70  {}
71  };
72 
73  template<>
74  class Normalise<ClassificationTag>
75  {
76  public:
77  template<class Iter>
78  static void exec (Iter begin, Iter end)
79  {
80  double bla = std::accumulate(begin, end, 0.0);
81  for(int ii = 0; ii < end - begin; ++ii)
82  begin[ii] = begin[ii]/bla ;
83  }
84  };
85 }
86 
87 
88 /** Base class for all split functors used with the \ref RandomForest class;
89  defines the interface used while learning a tree.
90 **/
91 template<class Tag>
92 class SplitBase
93 {
94  public:
95 
96  typedef Tag RF_Tag;
99 
100  ProblemSpec<> ext_param_;
101 
102  ArrayVector<Int32> t_data;
103  ArrayVector<double> p_data;
104 
105  NodeBase node_;
106 
107  /** returns the DecisionTree Node created by
108  \ref SplitBase::findBestSplit() or \ref SplitBase::makeTerminalNode().
109  **/
110 
111  template<class T>
112  void set_external_parameters(ProblemSpec<T> const & in)
113  {
114  ext_param_ = in;
115  t_data.push_back(in.column_count_);
116  t_data.push_back(in.class_count_);
117  }
118 
119  NodeBase & createNode()
120  {
121  return node_;
122  }
123 
124  int classCount() const
125  {
126  return int(t_data[1]);
127  }
128 
129  int featureCount() const
130  {
131  return int(t_data[0]);
132  }
133 
134  /** resets internal data. Should always be called before
135  calling findBestSplit or makeTerminalNode
136  **/
137  void reset()
138  {
139  t_data.resize(2);
140  p_data.resize(0);
141  }
142 
143 
144  /** findBestSplit has to be re-implemented in the derived split functor.
145  The default implementation only ensures that a compile-time error is issued
146  if no such method was defined.
147  **/
148 
149  template<class T, class C, class T2, class C2, class Region, class Random>
150  int findBestSplit(MultiArrayView<2, T, C> /*features*/,
151  MultiArrayView<2, T2, C2> /*labels*/,
152  Region /*region*/,
153  ArrayVector<Region> /*childs*/,
154  Random /*randint*/)
155  {
156 #ifndef __clang__
157  // FIXME: This compile-time checking trick does not work for clang.
158  CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
159 #endif
160  return 0;
161  }
162 
163  /** Default action for creating a terminal node:
164  sets the class probabilities of the remaining region according to
165  the class histogram.
166  **/
167  template<class T, class C, class T2,class C2, class Region, class Random>
168  int makeTerminalNode(MultiArrayView<2, T, C> /* features */,
169  MultiArrayView<2, T2, C2> /* labels */,
170  Region & region,
171  Random /* randint */)
172  {
173  Node<e_ConstProbNode> ret(t_data, p_data);
174  node_ = ret;
175  if(ext_param_.class_weights_.size() != region.classCounts().size())
176  {
177  std::copy(region.classCounts().begin(),
178  region.classCounts().end(),
179  ret.prob_begin());
180  }
181  else
182  {
183  std::transform(region.classCounts().begin(),
184  region.classCounts().end(),
185  ext_param_.class_weights_.begin(),
186  ret.prob_begin(), std::multiplies<double>());
187  }
188  detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
189 // std::copy(ret.prob_begin(), ret.prob_end(), std::ostream_iterator<double>(std::cerr, ", " ));
190 // std::cerr << std::endl;
191  ret.weights() = region.size();
192  return e_ConstProbNode;
193  }
194 
195 
196 };
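SplitBase only supplies shared plumbing (the external parameters, the node data arrays, and the default terminal-node action); a concrete functor derives from it and implements findBestSplit. As a minimal sketch of that contract — the class AlwaysTerminalSplit below is an invented illustration, not part of the VIGRA sources — a functor that never splits can simply refresh the class histogram and delegate to makeTerminalNode():

    // Illustration only: a split functor that always produces a leaf via the
    // default SplitBase action. It uses the same findBestSplit() signature
    // that ThresholdSplit implements further below.
    template<class Tag = ClassificationTag>
    class AlwaysTerminalSplit : public SplitBase<Tag>
    {
      public:
        template<class T, class C, class T2, class C2, class Region, class Random>
        int findBestSplit(MultiArrayView<2, T, C> features,
                          MultiArrayView<2, T2, C2> labels,
                          Region & region,
                          ArrayVector<Region> & /* childRegions */,
                          Random & randint)
        {
            // Refresh the class histogram of the region if necessary
            // (detail::Correction is defined later in this file).
            detail::Correction<Tag>::exec(region, labels);
            // No search at all: turn the whole region into a terminal node.
            return this->makeTerminalNode(features, labels, region, randint);
        }
    };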
197 
198 /** Functor to sort the indices of a feature matrix by a certain dimension
199 **/
200 template<class DataMatrix>
201 class SortSamplesByDimensions
202 {
203  DataMatrix const & data_;
204  MultiArrayIndex sortColumn_;
205  double thresVal_;
206  public:
207 
208  SortSamplesByDimensions(DataMatrix const & data,
209  MultiArrayIndex sortColumn,
210  double thresVal = 0.0)
211  : data_(data),
212  sortColumn_(sortColumn),
213  thresVal_(thresVal)
214  {}
215 
216  void setColumn(MultiArrayIndex sortColumn)
217  {
218  sortColumn_ = sortColumn;
219  }
220  void setThreshold(double value)
221  {
222  thresVal_ = value;
223  }
224 
225  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
226  {
227  return data_(l, sortColumn_) < data_(r, sortColumn_);
228  }
229  bool operator()(MultiArrayIndex l) const
230  {
231  return data_(l, sortColumn_) < thresVal_;
232  }
233 };
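The functor above is used in two ways later in this file: as a binary predicate for std::sort() (compare two sample indices by one feature column) and as a unary predicate for std::partition() (test one index against a threshold). A small sketch of both uses; the data and the names features, numSamples, numFeatures, the column index 3 and the threshold 0.5 are made up for illustration:

    // Sort sample indices by feature column 3, then split them at threshold 0.5,
    // mirroring the calls made by BestGiniOfColumn and ThresholdSplit below.
    MultiArray<2, double> features(MultiArrayShape<2>::type(numSamples, numFeatures));
    // ... fill features ...
    ArrayVector<int> indices(numSamples);
    for(int i = 0; i < numSamples; ++i)
        indices[i] = i;

    SortSamplesByDimensions<MultiArray<2, double> > byColumn(features, 3, 0.5);
    std::sort(indices.begin(), indices.end(), byColumn);           // binary operator()
    ArrayVector<int>::iterator split =
        std::partition(indices.begin(), indices.end(), byColumn);  // unary operator()
    // [indices.begin(), split) now holds exactly the samples with features(i, 3) < 0.5.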
234 
235 template<class DataMatrix>
236 class DimensionNotEqual
237 {
238  DataMatrix const & data_;
239  MultiArrayIndex sortColumn_;
240 
241  public:
242 
243  DimensionNotEqual(DataMatrix const & data,
244  MultiArrayIndex sortColumn)
245  : data_(data),
246  sortColumn_(sortColumn)
247  {}
248 
249  void setColumn(MultiArrayIndex sortColumn)
250  {
251  sortColumn_ = sortColumn;
252  }
253 
254  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
255  {
256  return data_(l, sortColumn_) != data_(r, sortColumn_);
257  }
258 };
259 
260 template<class DataMatrix>
261 class SortSamplesByHyperplane
262 {
263  DataMatrix const & data_;
264  Node<i_HyperplaneNode> const & node_;
265 
266  public:
267 
268  SortSamplesByHyperplane(DataMatrix const & data,
269  Node<i_HyperplaneNode> const & node)
270  :
271  data_(data),
272  node_(node)
273  {}
274 
275  /** calculate the distance of a sample point to a hyperplane
276  */
277  double operator[](MultiArrayIndex l) const
278  {
279  double result_l = -1 * node_.intercept();
280  for(int ii = 0; ii < node_.columns_size(); ++ii)
281  {
282  result_l += rowVector(data_, l)[node_.columns_begin()[ii]]
283  * node_.weights()[ii];
284  }
285  return result_l;
286  }
287 
288  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
289  {
290  return (*this)[l] < (*this)[r];
291  }
292 
293 };
294 
295 /** Makes a class histogram given indices into a labels array.
296  * Usage:
297  * MultiArrayView<2, T2, C2> labels = makeSomeLabels();
298  * ArrayVector<int> hist(numberOfLabels(labels), 0);
299  * RandomForestClassCounter<MultiArrayView<2, T2, C2>, ArrayVector<int> > counter(labels, hist);
300  *
301  * Container<int> indices = getSomeIndices();
302  * std::for_each(indices.begin(), indices.end(), counter);
303  */
304 template <class DataSource, class CountArray>
305 class RandomForestClassCounter
306 {
307  DataSource const & labels_;
308  CountArray & counts_;
309 
310  public:
311 
312  RandomForestClassCounter(DataSource const & labels,
313  CountArray & counts)
314  : labels_(labels),
315  counts_(counts)
316  {
317  reset();
318  }
319 
320  void reset()
321  {
322  counts_.init(0);
323  }
324 
325  void operator()(MultiArrayIndex l) const
326  {
327  counts_[labels_[l]] +=1;
328  }
329 };
330 
331 
332 /** Functor to calculate the best possible split based on the Gini index,
333  given labels and features along a given axis
334 */
335 
336 namespace detail
337 {
338  template<int N>
339  class ConstArr
340  {
341  public:
342  double operator[](size_t) const
343  {
344  return (double)N;
345  }
346  };
347 
348 
349 }
350 
351 
352 
353 
354 /** Functor to calculate the entropy-based impurity
355  */
356 class EntropyCriterion
357 {
358 public:
359  /** calculate the weighted entropy impurity based on the class histogram
360  * and class weights
361  */
362  template<class Array, class Array2>
363  double operator() (Array const & hist,
364  Array2 const & weights,
365  double total = 1.0) const
366  {
367  return impurity(hist, weights, total);
368  }
369 
370  /** calculate the entropy-based impurity from the class histogram
371  */
372  template<class Array>
373  double operator()(Array const & hist, double total = 1.0) const
374  {
375  return impurity(hist, total);
376  }
377 
378  /** static version of operator()(hist, total)
379  */
380  template<class Array>
381  static double impurity(Array const & hist, double total)
382  {
383  return impurity(hist, detail::ConstArr<1>(), total);
384  }
385 
386  /** static version of operator(hist, weights, total)
387  */
388  template<class Array, class Array2>
389  static double impurity (Array const & hist,
390  Array2 const & weights,
391  double total)
392  {
393 
394  int class_count = hist.size();
395  double entropy = 0.0;
396  if(class_count == 2)
397  {
398  double p0 = (hist[0]/total);
399  double p1 = (hist[1]/total);
400  entropy = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
401  }
402  else
403  {
404  for(int ii = 0; ii < class_count; ++ii)
405  {
406  double w = weights[ii];
407  double pii = hist[ii]/total;
408  entropy -= w*( pii*std::log(pii));
409  }
410  }
411  entropy = total * entropy;
412  return entropy;
413  }
414 };
415 
416 /** Functor to calculate the Gini impurity
417  */
418 class GiniCriterion
419 {
420 public:
421  /** calculate the weighted Gini impurity based on the class histogram
422  * and class weights
423  */
424  template<class Array, class Array2>
425  double operator() (Array const & hist,
426  Array2 const & weights,
427  double total = 1.0) const
428  {
429  return impurity(hist, weights, total);
430  }
431 
432  /** calculate the Gini-based impurity from the class histogram
433  */
434  template<class Array>
435  double operator()(Array const & hist, double total = 1.0) const
436  {
437  return impurity(hist, total);
438  }
439 
440  /** static version of operator()(hist, total)
441  */
442  template<class Array>
443  static double impurity(Array const & hist, double total)
444  {
445  return impurity(hist, detail::ConstArr<1>(), total);
446  }
447 
448  /** static version of operator(hist, weights, total)
449  */
450  template<class Array, class Array2>
451  static double impurity (Array const & hist,
452  Array2 const & weights,
453  double total)
454  {
455 
456  int class_count = hist.size();
457  double gini = 0.0;
458  if(class_count == 2)
459  {
460  double w = weights[0] * weights[1];
461  gini = w * (hist[0] * hist[1] / total);
462  }
463  else
464  {
465  for(int ii = 0; ii < class_count; ++ii)
466  {
467  double w = weights[ii];
468  gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
469  }
470  }
471  return gini;
472  }
473 };
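A quick numeric check of the two-class branch above (illustration only): for a histogram {6, 2} with unit class weights, GiniCriterion returns hist[0]*hist[1]/total = total*p0*p1, which is exactly half of the size-weighted Gini impurity total*(1 - p0^2 - p1^2), so candidate splits are ranked the same way.

    // Evaluate GiniCriterion on a two-class histogram {6, 2}.
    ArrayVector<double> hist(2);
    hist[0] = 6.0;
    hist[1] = 2.0;
    double total = hist[0] + hist[1];                  // 8 samples in the region
    double g1 = GiniCriterion::impurity(hist, total);  // unit weights: 6*2/8 = 1.5
    GiniCriterion gini;
    double g2 = gini(hist, total);                     // same value via operator()
    // 1.5 == total * p0 * p1 == 8 * 0.75 * 0.25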
474 
475 
476 template <class DataSource, class Impurity= GiniCriterion>
477 class ImpurityLoss
478 {
479 
480  DataSource const & labels_;
481  ArrayVector<double> counts_;
482  ArrayVector<double> const class_weights_;
483  double total_counts_;
484  Impurity impurity_;
485 
486  public:
487 
488  template<class T>
489  ImpurityLoss(DataSource const & labels,
490  ProblemSpec<T> const & ext_)
491  : labels_(labels),
492  counts_(ext_.class_count_, 0.0),
493  class_weights_(ext_.class_weights_),
494  total_counts_(0.0)
495  {}
496 
497  void reset()
498  {
499  counts_.init(0);
500  total_counts_ = 0.0;
501  }
502 
503  template<class Counts>
504  double increment_histogram(Counts const & counts)
505  {
506  std::transform(counts.begin(), counts.end(),
507  counts_.begin(), counts_.begin(),
508  std::plus<double>());
509  total_counts_ = std::accumulate( counts_.begin(),
510  counts_.end(),
511  0.0);
512  return impurity_(counts_, class_weights_, total_counts_);
513  }
514 
515  template<class Counts>
516  double decrement_histogram(Counts const & counts)
517  {
518  std::transform(counts.begin(), counts.end(),
519  counts_.begin(), counts_.begin(),
520  std::minus<double>());
521  total_counts_ = std::accumulate( counts_.begin(),
522  counts_.end(),
523  0.0);
524  return impurity_(counts_, class_weights_, total_counts_);
525  }
526 
527  template<class Iter>
528  double increment(Iter begin, Iter end)
529  {
530  for(Iter iter = begin; iter != end; ++iter)
531  {
532  counts_[labels_(*iter, 0)] +=1.0;
533  total_counts_ +=1.0;
534  }
535  return impurity_(counts_, class_weights_, total_counts_);
536  }
537 
538  template<class Iter>
539  double decrement(Iter const & begin, Iter const & end)
540  {
541  for(Iter iter = begin; iter != end; ++iter)
542  {
543  counts_[labels_(*iter,0)] -=1.0;
544  total_counts_ -=1.0;
545  }
546  return impurity_(counts_, class_weights_, total_counts_);
547  }
548 
549  template<class Iter, class Resp_t>
550  double init (Iter /*begin*/, Iter /*end*/, Resp_t resp)
551  {
552  reset();
553  std::copy(resp.begin(), resp.end(), counts_.begin());
554  total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
555  return impurity_(counts_,class_weights_, total_counts_);
556  }
557 
558  ArrayVector<double> const & response()
559  {
560  return counts_;
561  }
562 };
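BestGiniOfColumn (defined further below) drives this class with two instances — one accumulating the samples left of a candidate threshold, one holding those still on the right — and shifts samples between them while scanning an index range sorted by the feature column. The helper below is a stripped-down sketch of that scan (the function name is invented; the real code also skips runs of equal feature values with std::adjacent_find and records the best threshold):

    // Sketch: return the smallest combined left+right impurity over all split
    // positions of an index range already sorted by one feature column.
    // `ext` must carry class_count_ and class weights, as set up during learning.
    template<class Labels, class IndexIter, class Hist>
    double bestLossOfSortedRange(Labels const & labels, ProblemSpec<> const & ext,
                                 IndexIter begin, IndexIter end, Hist const & regionHist)
    {
        ImpurityLoss<Labels> left(labels, ext);
        ImpurityLoss<Labels> right(labels, ext);
        double best = right.init(begin, end, regionHist);    // all samples start on the right
        for(IndexIter iter = begin; iter != end && iter + 1 != end; ++iter)
        {
            // move sample *iter from the right-hand to the left-hand region
            double loss = right.decrement(iter, iter + 1)
                        + left.increment(iter, iter + 1);
            if(loss < best)
                best = loss;
        }
        return best;
    }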
563 
564 
565 
566  template <class DataSource>
567  class RegressionForestCounter
568  {
569  public:
570  typedef MultiArrayShape<2>::type Shp;
571  DataSource const & labels_;
572  ArrayVector <double> mean_;
573  ArrayVector <double> variance_;
574  ArrayVector <double> tmp_;
575  size_t count_;
576  int* end_;
577 
578  template<class T>
579  RegressionForestCounter(DataSource const & labels,
580  ProblemSpec<T> const & ext_)
581  :
582  labels_(labels),
583  mean_(ext_.response_size_, 0.0),
584  variance_(ext_.response_size_, 0.0),
585  tmp_(ext_.response_size_),
586  count_(0)
587  {}
588 
589  template<class Iter>
590  double increment (Iter begin, Iter end)
591  {
592  for(Iter iter = begin; iter != end; ++iter)
593  {
594  ++count_;
595  for(unsigned int ii = 0; ii < mean_.size(); ++ii)
596  tmp_[ii] = labels_(*iter, ii) - mean_[ii];
597  double f = 1.0 / count_,
598  f1 = 1.0 - f;
599  for(unsigned int ii = 0; ii < mean_.size(); ++ii)
600  mean_[ii] += f*tmp_[ii];
601  for(unsigned int ii = 0; ii < mean_.size(); ++ii)
602  variance_[ii] += f1*sq(tmp_[ii]);
603  }
604  double res = std::accumulate(variance_.begin(),
605  variance_.end(),
606  0.0,
607  std::plus<double>());
608  //std::cerr << res << " ) = ";
609  return res;
610  }
611 
612  template<class Iter> //This is BROKEN
613  double decrement (Iter begin, Iter end)
614  {
615  for(Iter iter = begin; iter != end; ++iter)
616  {
617  --count_;
618  }
619 
620  begin = end;
621  end = end + count_;
622 
623 
624  for(unsigned int ii = 0; ii < mean_.size(); ++ii)
625  {
626  mean_[ii] = 0;
627  for(Iter iter = begin; iter != end; ++iter)
628  {
629  mean_[ii] += labels_(*iter, ii);
630  }
631  mean_[ii] /= count_;
632  variance_[ii] = 0;
633  for(Iter iter = begin; iter != end; ++iter)
634  {
635  variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]);
636  }
637  }
638  double res = std::accumulate(variance_.begin(),
639  variance_.end(),
640  0.0,
641  std::plus<double>());
642  //std::cerr << res << " ) = ";
643  return res;
644  }
645 
646 
647  template<class Iter, class Resp_t>
648  double init (Iter begin, Iter end, Resp_t /*resp*/)
649  {
650  reset();
651  return this->increment(begin, end);
652 
653  }
654 
655 
656  ArrayVector<double> const & response()
657  {
658  return mean_;
659  }
660 
661  void reset()
662  {
663  mean_.init(0.0);
664  variance_.init(0.0);
665  count_ = 0;
666  }
667  };
668 
669 
670 template <class DataSource>
671 class RegressionForestCounter2
672 {
673 public:
674  typedef MultiArrayShape<2>::type Shp;
675  DataSource const & labels_;
676  ArrayVector <double> mean_;
677  ArrayVector <double> variance_;
678  ArrayVector <double> tmp_;
679  size_t count_;
680 
681  template<class T>
682  RegressionForestCounter2(DataSource const & labels,
683  ProblemSpec<T> const & ext_)
684  :
685  labels_(labels),
686  mean_(ext_.response_size_, 0.0),
687  variance_(ext_.response_size_, 0.0),
688  tmp_(ext_.response_size_),
689  count_(0)
690  {}
691 
692  template<class Iter>
693  double increment (Iter begin, Iter end)
694  {
695  for(Iter iter = begin; iter != end; ++iter)
696  {
697  ++count_;
698  for(int ii = 0; ii < mean_.size(); ++ii)
699  tmp_[ii] = labels_(*iter, ii) - mean_[ii];
700  double f = 1.0 / count_,
701  f1 = 1.0 - f;
702  for(int ii = 0; ii < mean_.size(); ++ii)
703  mean_[ii] += f*tmp_[ii];
704  for(int ii = 0; ii < mean_.size(); ++ii)
705  variance_[ii] += f1*sq(tmp_[ii]);
706  }
707  double res = std::accumulate(variance_.begin(),
708  variance_.end(),
709  0.0,
710  std::plus<double>())
711  /((count_ == 1)? 1:(count_ -1));
712  //std::cerr << res << " ) = ";
713  return res;
714  }
715 
716  template<class Iter> //This is BROKEN
717  double decrement (Iter begin, Iter end)
718  {
719  for(Iter iter = begin; iter != end; ++iter)
720  {
721  double f = 1.0 / count_,
722  f1 = 1.0 - f;
723  for(int ii = 0; ii < mean_.size(); ++ii)
724  mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f);
725  for(int ii = 0; ii < mean_.size(); ++ii)
726  variance_[ii] -= f1*sq(labels_(*iter,ii) - mean_[ii]);
727  --count_;
728  }
729  double res = std::accumulate(variance_.begin(),
730  variance_.end(),
731  0.0,
732  std::plus<double>())
733  /((count_ == 1)? 1:(count_ -1));
734  //std::cerr << "( " << res << " + ";
735  return res;
736  }
737  /* west's algorithm for incremental variance
738  // calculation
739  template<class Iter>
740  double increment (Iter begin, Iter end)
741  {
742  for(Iter iter = begin; iter != end; ++iter)
743  {
744  ++count_;
745  for(int ii = 0; ii < mean_.size(); ++ii)
746  tmp_[ii] = labels_(*iter, ii) - mean_[ii];
747  double f = 1.0 / count_,
748  f1 = 1.0 - f;
749  for(int ii = 0; ii < mean_.size(); ++ii)
750  mean_[ii] += f*tmp_[ii];
751  for(int ii = 0; ii < mean_.size(); ++ii)
752  variance_[ii] += f1*sq(tmp_[ii]);
753  }
754  return std::accumulate(variance_.begin(),
755  variance_.end(),
756  0.0,
757  std::plus<double>())
758  /(count_ -1);
759  }
760 
761  template<class Iter>
762  double decrement (Iter begin, Iter end)
763  {
764  for(Iter iter = begin; iter != end; ++iter)
765  {
766  --count_;
767  for(int ii = 0; ii < mean_.size(); ++ii)
768  tmp_[ii] = labels_(*iter, ii) - mean_[ii];
769  double f = 1.0 / count_,
770  f1 = 1.0 + f;
771  for(int ii = 0; ii < mean_.size(); ++ii)
772  mean_[ii] -= f*tmp_[ii];
773  for(int ii = 0; ii < mean_.size(); ++ii)
774  variance_[ii] -= f1*sq(tmp_[ii]);
775  }
776  return std::accumulate(variance_.begin(),
777  variance_.end(),
778  0.0,
779  std::plus<double>())
780  /(count_ -1);
781  }*/
782 
783  template<class Iter, class Resp_t>
784  double init (Iter begin, Iter end, Resp_t resp)
785  {
786  reset();
787  return this->increment(begin, end, resp);
788  }
789 
790 
791  ArrayVector<double> const & response()
792  {
793  return mean_;
794  }
795 
796  void reset()
797  {
798  mean_.init(0.0);
799  variance_.init(0.0);
800  count_ = 0;
801  }
802 };
803 
804 template<class Tag, class Datatyp>
805 struct LossTraits;
806 
807 struct LSQLoss
808 {};
809 
810 template<class Datatype>
811 struct LossTraits<GiniCriterion, Datatype>
812 {
813  typedef ImpurityLoss<Datatype, GiniCriterion> type;
814 };
815 
816 template<class Datatype>
817 struct LossTraits<EntropyCriterion, Datatype>
818 {
819  typedef ImpurityLoss<Datatype, EntropyCriterion> type;
820 };
821 
822 template<class Datatype>
823 struct LossTraits<LSQLoss, Datatype>
824 {
825  typedef RegressionForestCounter<Datatype> type;
826 };
827 
828 /** Given a column, choose a split that minimizes some loss
829  */
830 template<class LineSearchLossTag>
831 class BestGiniOfColumn
832 {
833 public:
834  ArrayVector<double> class_weights_;
835  ArrayVector<double> bestCurrentCounts[2];
836  double min_gini_;
837  std::ptrdiff_t min_index_;
838  double min_threshold_;
839  ProblemSpec<> ext_param_;
840 
841  BestGiniOfColumn()
842  {}
843 
844  template<class T>
845  BestGiniOfColumn(ProblemSpec<T> const & ext)
846  :
847  class_weights_(ext.class_weights_),
848  ext_param_(ext)
849  {
850  bestCurrentCounts[0].resize(ext.class_count_);
851  bestCurrentCounts[1].resize(ext.class_count_);
852  }
853  template<class T>
854  void set_external_parameters(ProblemSpec<T> const & ext)
855  {
856  class_weights_ = ext.class_weights_;
857  ext_param_ = ext;
858  bestCurrentCounts[0].resize(ext.class_count_);
859  bestCurrentCounts[1].resize(ext.class_count_);
860  }
861  /** calculate the best gini split along a Feature Column
862  * \param column the feature vector - has to support the [] operator
863  * \param labels the label vector
864  * \param begin
865  * \param end (in and out)
866  * begin and end iterators to the indices of the
867  * samples in the current region.
868  * the range begin - end is sorted by the column supplied
869  * during function execution.
870  * \param region_response
871  * class histogram of the range
872  * (i.e. the class counts of the samples in [begin, end)).
873  *
874  * precondition: begin, end valid range,
875  * class_counts positive integer valued array with the
876  * class counts in the current range.
877  * labels.size() >= max(begin, end);
878  * postcondition:
879  * begin, end sorted by column given.
880  * min_gini_ contains the minimum gini found or
881  * NumericTraits<double>::max if no split was found.
882  * min_index_ contains the splitting index in the range
883  * or invalid data if no split was found.
884  * bestCurrentCounts[0] and [1] contain the
885  * class histograms of the left and right
886  * regions, respectively.
887  */
888  template< class DataSourceF_t,
889  class DataSource_t,
890  class I_Iter,
891  class Array>
892  void operator()(DataSourceF_t const & column,
893  DataSource_t const & labels,
894  I_Iter & begin,
895  I_Iter & end,
896  Array const & region_response)
897  {
898  std::sort(begin, end,
899  SortSamplesByDimensions<DataSourceF_t>(column, 0));
900  typedef typename
901  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
902  LineSearchLoss left(labels, ext_param_); //initialize left and right region
903  LineSearchLoss right(labels, ext_param_);
904 
905 
906 
907  min_gini_ = right.init(begin, end, region_response);
908  min_threshold_ = *begin;
909  min_index_ = 0; //the starting point where to split
910  DimensionNotEqual<DataSourceF_t> comp(column, 0);
911 
912  I_Iter iter = begin;
913  I_Iter next = std::adjacent_find(iter, end, comp);
914  //std::cerr << std::distance(begin, end) << std::endl;
915  while( next != end)
916  {
917  double lr = right.decrement(iter, next + 1);
918  double ll = left.increment(iter , next + 1);
919  double loss = lr +ll;
920  //std::cerr <<lr << " + "<< ll << " " << loss << " ";
921 #ifdef CLASSIFIER_TEST
922  if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
923 #else
924  if(loss < min_gini_ )
925 #endif
926  {
927  bestCurrentCounts[0] = left.response();
928  bestCurrentCounts[1] = right.response();
929 #ifdef CLASSIFIER_TEST
930  min_gini_ = loss < min_gini_? loss : min_gini_;
931 #else
932  min_gini_ = loss;
933 #endif
934  min_index_ = next - begin +1 ;
935  min_threshold_ = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
936  }
937  iter = next +1 ;
938  next = std::adjacent_find(iter, end, comp);
939  }
940  //std::cerr << std::endl << " 000 " << std::endl;
941  //int in;
942  //std::cin >> in;
943  }
944 
945  template<class DataSource_t, class Iter, class Array>
946  double loss_of_region(DataSource_t const & labels,
947  Iter & begin,
948  Iter & end,
949  Array const & region_response) const
950  {
951  typedef typename
952  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
953  LineSearchLoss region_loss(labels, ext_param_);
954  return
955  region_loss.init(begin, end, region_response);
956  }
957 
958 };
959 
960 namespace detail
961 {
962  template<class T>
963  struct Correction
964  {
965  template<class Region, class LabelT>
966  static void exec(Region & /*in*/, LabelT & /*labels*/)
967  {}
968  };
969 
970  template<>
971  struct Correction<ClassificationTag>
972  {
973  template<class Region, class LabelT>
974  static void exec(Region & region, LabelT & labels)
975  {
976  if(std::accumulate(region.classCounts().begin(),
977  region.classCounts().end(), 0.0) != region.size())
978  {
979  RandomForestClassCounter< LabelT,
980  ArrayVector<double> >
981  counter(labels, region.classCounts());
982  std::for_each( region.begin(), region.end(), counter);
983  region.classCountsIsValid = true;
984  }
985  }
986  };
987 }
988 
989 /** Chooses mtry columns and applies ColumnDecisionFunctor to each of the
990  * columns, then chooses the column that is best.
991  */
992 template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
993 class ThresholdSplit: public SplitBase<Tag>
994 {
995  public:
996 
997 
998  typedef SplitBase<Tag> SB;
999 
1000  ArrayVector<Int32> splitColumns;
1001  ColumnDecisionFunctor bgfunc;
1002 
1003  double region_gini_;
1004  ArrayVector<double> min_gini_;
1005  ArrayVector<std::ptrdiff_t> min_indices_;
1006  ArrayVector<double> min_thresholds_;
1007 
1008  int bestSplitIndex;
1009 
1010  double minGini() const
1011  {
1012  return min_gini_[bestSplitIndex];
1013  }
1014  int bestSplitColumn() const
1015  {
1016  return splitColumns[bestSplitIndex];
1017  }
1018  double bestSplitThreshold() const
1019  {
1020  return min_thresholds_[bestSplitIndex];
1021  }
1022 
1023  template<class T>
1024  void set_external_parameters(ProblemSpec<T> const & in)
1025  {
1026  SB::set_external_parameters(in);
1027  bgfunc.set_external_parameters( SB::ext_param_);
1028  int featureCount_ = SB::ext_param_.column_count_;
1029  splitColumns.resize(featureCount_);
1030  for(int k=0; k<featureCount_; ++k)
1031  splitColumns[k] = k;
1032  min_gini_.resize(featureCount_);
1033  min_indices_.resize(featureCount_);
1034  min_thresholds_.resize(featureCount_);
1035  }
1036 
1037 
1038  template<class T, class C, class T2, class C2, class Region, class Random>
1039  int findBestSplit(MultiArrayView<2, T, C> features,
1040  MultiArrayView<2, T2, C2> labels,
1041  Region & region,
1042  ArrayVector<Region>& childRegions,
1043  Random & randint)
1044  {
1045 
1046  typedef typename Region::IndexIterator IndexIterator;
1047  if(region.size() == 0)
1048  {
1049  std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
1050  "continuing learning process....";
1051  }
1052  // calculate things that haven't been calculated yet.
1053  detail::Correction<Tag>::exec(region, labels);
1054 
1055 
1056  // Is the region pure already?
1057  region_gini_ = bgfunc.loss_of_region(labels,
1058  region.begin(),
1059  region.end(),
1060  region.classCounts());
1061  if(region_gini_ <= SB::ext_param_.precision_)
1062  return this->makeTerminalNode(features, labels, region, randint);
1063 
1064  // select columns to be tried.
1065  for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
1066  std::swap(splitColumns[ii],
1067  splitColumns[ii+ randint(features.shape(1) - ii)]);
1068 
1069  // find the best gini index
1070  bestSplitIndex = 0;
1071  double current_min_gini = region_gini_;
1072  int num2try = features.shape(1);
1073  for(int k=0; k<num2try; ++k)
1074  {
1075  //this functor does all the work
1076  bgfunc(columnVector(features, splitColumns[k]),
1077  labels,
1078  region.begin(), region.end(),
1079  region.classCounts());
1080  min_gini_[k] = bgfunc.min_gini_;
1081  min_indices_[k] = bgfunc.min_index_;
1082  min_thresholds_[k] = bgfunc.min_threshold_;
1083 #ifdef CLASSIFIER_TEST
1084  if( bgfunc.min_gini_ < current_min_gini
1085  && !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
1086 #else
1087  if(bgfunc.min_gini_ < current_min_gini)
1088 #endif
1089  {
1090  current_min_gini = bgfunc.min_gini_;
1091  childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
1092  childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
1093  childRegions[0].classCountsIsValid = true;
1094  childRegions[1].classCountsIsValid = true;
1095 
1096  bestSplitIndex = k;
1097  num2try = SB::ext_param_.actual_mtry_;
1098  }
1099  }
1100  //std::cerr << current_min_gini << "curr " << region_gini_ << std::endl;
1101  // did not find any suitable split
1102  // FIXME: this is wrong: sometimes we must execute bad splits to make progress,
1103  // especially near the root.
1104  if(closeAtTolerance(current_min_gini, region_gini_))
1105  return this->makeTerminalNode(features, labels, region, randint);
1106 
1107  //create a Node for output
1108  Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
1109  SB::node_ = node;
1110  node.threshold() = min_thresholds_[bestSplitIndex];
1111  node.column() = splitColumns[bestSplitIndex];
1112 
1113  // partition the range according to the best dimension
1114  SortSamplesByDimensions<MultiArrayView<2, T, C> >
1115  sorter(features, node.column(), node.threshold());
1116  IndexIterator bestSplit =
1117  std::partition(region.begin(), region.end(), sorter);
1118  // Save the ranges of the child stack entries.
1119  childRegions[0].setRange( region.begin() , bestSplit );
1120  childRegions[0].rule = region.rule;
1121  childRegions[0].rule.push_back(std::make_pair(1, 1.0));
1122  childRegions[1].setRange( bestSplit , region.end() );
1123  childRegions[1].rule = region.rule;
1124  childRegions[1].rule.push_back(std::make_pair(1, 1.0));
1125 
1126  return i_ThresholdNode;
1127  }
1128 };
1129 
1129 
1130 typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;
1131 typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> > EntropySplit;
1132 typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
1133 
1134 namespace rf
1135 {
1136 
1137 /** This namespace contains additional split functors.
1138  *
1139  * The split functor classes are designed in a modular fashion because new split functors may
1140  * share a lot of code with existing ones.
1141  *
1142  * ThresholdSplit implements the functionality needed for any split functor that makes its
1143  * decision via one-dimensional, axis-parallel cuts. The template parameter defines how the split
1144  * along one dimension is chosen.
1145  *
1146  * The BestGiniOfColumn class chooses a split that minimizes one of the loss functions supplied
1147  * (GiniCriterion for classification and LSQLoss for regression). Median chooses the split in a
1148  * kd-tree fashion.
1149  *
1150  *
1151  * Currently defined typedefs:
1152  * \code
1153  * typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;
1154  * typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit;
1155  * typedef ThresholdSplit<Median> MedianSplit;
1156  * \endcode
1157  */
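For orientation, here is a usage sketch (my own example, not taken from the VIGRA sources) showing how one of these typedefs could be handed to RandomForest::learn(); the overload taking (visitor, split, stop) arguments comes from the random forest headers and its exact form may differ between VIGRA versions:

    #include <vigra/multi_array.hxx>
    #include <vigra/random_forest.hxx>

    void trainWithGiniSplit(vigra::MultiArrayView<2, double> features,
                            vigra::MultiArrayView<2, int>    labels)
    {
        using namespace vigra;

        RandomForest<int> rf(RandomForestOptions().tree_count(100));

        // Explicitly request the Gini-based threshold split defined in rf_split.hxx.
        GiniSplit splitFunctor;
        rf.learn(features, labels,
                 rf_default(),     // visitor
                 splitFunctor,     // ThresholdSplit<BestGiniOfColumn<GiniCriterion> >
                 rf_default());    // stop criterion
    }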
1158 namespace split
1159 {
1160 
1161 /** This functor chooses the median value of a column
1162  */
1163 class Median
1164 {
1165 public:
1166 
1167  typedef GiniCriterion LineSearchLossTag;
1168  ArrayVector<double> class_weights_;
1169  ArrayVector<double> bestCurrentCounts[2];
1170  double min_gini_;
1171  std::ptrdiff_t min_index_;
1172  double min_threshold_;
1173  ProblemSpec<> ext_param_;
1174 
1175  Median()
1176  {}
1177 
1178  template<class T>
1179  Median(ProblemSpec<T> const & ext)
1180  :
1181  class_weights_(ext.class_weights_),
1182  ext_param_(ext)
1183  {
1184  bestCurrentCounts[0].resize(ext.class_count_);
1185  bestCurrentCounts[1].resize(ext.class_count_);
1186  }
1187 
1188  template<class T>
1189  void set_external_parameters(ProblemSpec<T> const & ext)
1190  {
1191  class_weights_ = ext.class_weights_;
1192  ext_param_ = ext;
1193  bestCurrentCounts[0].resize(ext.class_count_);
1194  bestCurrentCounts[1].resize(ext.class_count_);
1195  }
1196 
1197  template< class DataSourceF_t,
1198  class DataSource_t,
1199  class I_Iter,
1200  class Array>
1201  void operator()(DataSourceF_t const & column,
1202  DataSource_t const & labels,
1203  I_Iter & begin,
1204  I_Iter & end,
1205  Array const & region_response)
1206  {
1207  std::sort(begin, end,
1208  SortSamplesByDimensions<DataSourceF_t>(column, 0));
1209  typedef typename
1210  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1211  LineSearchLoss left(labels, ext_param_);
1212  LineSearchLoss right(labels, ext_param_);
1213  right.init(begin, end, region_response);
1214 
1215  min_gini_ = NumericTraits<double>::max();
1216  min_index_ = floor(double(end - begin)/2.0);
1217  min_threshold_ = column[*(begin + min_index_)];
1218  SortSamplesByDimensions<DataSourceF_t>
1219  sorter(column, 0, min_threshold_);
1220  I_Iter part = std::partition(begin, end, sorter);
1221  DimensionNotEqual<DataSourceF_t> comp(column, 0);
1222  if(part == begin)
1223  {
1224  part= std::adjacent_find(part, end, comp)+1;
1225 
1226  }
1227  if(part >= end)
1228  {
1229  return;
1230  }
1231  else
1232  {
1233  min_threshold_ = column[*part];
1234  }
1235  min_gini_ = right.decrement(begin, part)
1236  + left.increment(begin , part);
1237 
1238  bestCurrentCounts[0] = left.response();
1239  bestCurrentCounts[1] = right.response();
1240 
1241  min_index_ = part - begin;
1242  }
1243 
1244  template<class DataSource_t, class Iter, class Array>
1245  double loss_of_region(DataSource_t const & labels,
1246  Iter & begin,
1247  Iter & end,
1248  Array const & region_response) const
1249  {
1250  typedef typename
1251  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1252  LineSearchLoss region_loss(labels, ext_param_);
1253  return
1254  region_loss.init(begin, end, region_response);
1255  }
1256 
1257 };
1258 
1258 
1259 typedef ThresholdSplit<Median> MedianSplit;
1260 
1261 
1262 /** This functor chooses a random value of a column
1263  */
1264 class RandomSplitOfColumn
1265 {
1266 public:
1267 
1268  typedef GiniCriterion LineSearchLossTag;
1269  ArrayVector<double> class_weights_;
1270  ArrayVector<double> bestCurrentCounts[2];
1271  double min_gini_;
1272  std::ptrdiff_t min_index_;
1273  double min_threshold_;
1274  ProblemSpec<> ext_param_;
1275  typedef RandomMT19937 Random_t;
1276  Random_t random;
1277 
1278  RandomSplitOfColumn()
1279  {}
1280 
1281  template<class T>
1282  RandomSplitOfColumn(ProblemSpec<T> const & ext)
1283  :
1284  class_weights_(ext.class_weights_),
1285  ext_param_(ext),
1286  random(RandomSeed)
1287  {
1288  bestCurrentCounts[0].resize(ext.class_count_);
1289  bestCurrentCounts[1].resize(ext.class_count_);
1290  }
1291 
1292  template<class T>
1293  RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
1294  :
1295  class_weights_(ext.class_weights_),
1296  ext_param_(ext),
1297  random(random_)
1298  {
1299  bestCurrentCounts[0].resize(ext.class_count_);
1300  bestCurrentCounts[1].resize(ext.class_count_);
1301  }
1302 
1303  template<class T>
1304  void set_external_parameters(ProblemSpec<T> const & ext)
1305  {
1306  class_weights_ = ext.class_weights_;
1307  ext_param_ = ext;
1308  bestCurrentCounts[0].resize(ext.class_count_);
1309  bestCurrentCounts[1].resize(ext.class_count_);
1310  }
1311 
1312  template< class DataSourceF_t,
1313  class DataSource_t,
1314  class I_Iter,
1315  class Array>
1316  void operator()(DataSourceF_t const & column,
1317  DataSource_t const & labels,
1318  I_Iter & begin,
1319  I_Iter & end,
1320  Array const & region_response)
1321  {
1322  std::sort(begin, end,
1323  SortSamplesByDimensions<DataSourceF_t>(column, 0));
1324  typedef typename
1325  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1326  LineSearchLoss left(labels, ext_param_);
1327  LineSearchLoss right(labels, ext_param_);
1328  right.init(begin, end, region_response);
1329 
1330 
1331  min_gini_ = NumericTraits<double>::max();
1332  int tmp_pt = random.uniformInt(std::distance(begin, end));
1333  min_index_ = tmp_pt;
1334  min_threshold_ = column[*(begin + min_index_)];
1335  SortSamplesByDimensions<DataSourceF_t>
1336  sorter(column, 0, min_threshold_);
1337  I_Iter part = std::partition(begin, end, sorter);
1338  DimensionNotEqual<DataSourceF_t> comp(column, 0);
1339  if(part == begin)
1340  {
1341  part= std::adjacent_find(part, end, comp)+1;
1342 
1343  }
1344  if(part >= end)
1345  {
1346  return;
1347  }
1348  else
1349  {
1350  min_threshold_ = column[*part];
1351  }
1352  min_gini_ = right.decrement(begin, part)
1353  + left.increment(begin , part);
1354 
1355  bestCurrentCounts[0] = left.response();
1356  bestCurrentCounts[1] = right.response();
1357 
1358  min_index_ = part - begin;
1359  }
1360 
1361  template<class DataSource_t, class Iter, class Array>
1362  double loss_of_region(DataSource_t const & labels,
1363  Iter & begin,
1364  Iter & end,
1365  Array const & region_response) const
1366  {
1367  typedef typename
1368  LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1369  LineSearchLoss region_loss(labels, ext_param_);
1370  return
1371  region_loss.init(begin, end, region_response);
1372  }
1373 
1374 };
1375 
1376 typedef ThresholdSplit<RandomSplitOfColumn> RandomSplit;
1377 }
1378 }
1379 
1380 
1381 } //namespace vigra
1382 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
