35 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
36 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
42 #include "../mathutil.hxx"
43 #include "../array_vector.hxx"
44 #include "../sized_int.hxx"
45 #include "../matrix.hxx"
46 #include "../random.hxx"
47 #include "../functorexpression.hxx"
48 #include "rf_nodeproxy.hxx"
50 #include "rf_region.hxx"
// Compile-time "assertion" helper: this type is only forward-declared here
// and is meant to stay an incomplete type. The generic findBestSplit template
// declares an object of it (SplitFunctor__findBestSplit_member_was_not_defined),
// so instantiating that fallback is a compile error. The diagnostic will
// typically mention the variable's name, telling the user that their split
// functor must provide its own findBestSplit().
// NOTE(review): do not add a definition for this class, or the check silently disappears.
59 class CompileTimeError;
69 static void exec(Iter , Iter )
74 class Normalise<ClassificationTag>
78 static void exec (Iter begin, Iter end)
80 double bla = std::accumulate(begin, end, 0.0);
81 for(
int ii = 0; ii < end - begin; ++ii)
82 begin[ii] = begin[ii]/bla ;
115 t_data.push_back(in.column_count_);
116 t_data.push_back(in.class_count_);
124 int classCount()
const
126 return int(t_data[1]);
129 int featureCount()
const
131 return int(t_data[0]);
149 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
158 CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
167 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
175 if(ext_param_.class_weights_.
size() != region.classCounts().size())
177 std::copy(region.classCounts().begin(),
178 region.classCounts().end(),
183 std::transform(region.classCounts().begin(),
184 region.classCounts().end(),
185 ext_param_.class_weights_.
begin(),
186 ret.prob_begin(), std::multiplies<double>());
188 detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
192 return e_ConstProbNode;
200 template<
class DataMatrix>
203 DataMatrix
const & data_;
210 double thresVal = 0.0)
212 sortColumn_(sortColumn),
218 sortColumn_ = sortColumn;
220 void setThreshold(
double value)
227 return data_(l, sortColumn_) < data_(r, sortColumn_);
231 return data_(l, sortColumn_) < thresVal_;
235 template<
class DataMatrix>
236 class DimensionNotEqual
238 DataMatrix
const & data_;
243 DimensionNotEqual(DataMatrix
const & data,
246 sortColumn_(sortColumn)
251 sortColumn_ = sortColumn;
256 return data_(l, sortColumn_) != data_(r, sortColumn_);
260 template<
class DataMatrix>
261 class SortSamplesByHyperplane
263 DataMatrix
const & data_;
264 Node<i_HyperplaneNode>
const & node_;
268 SortSamplesByHyperplane(DataMatrix
const & data,
269 Node<i_HyperplaneNode>
const & node)
279 double result_l = -1 * node_.intercept();
280 for(
int ii = 0; ii < node_.columns_size(); ++ii)
282 result_l +=
rowVector(data_, l)[node_.columns_begin()[ii]]
283 * node_.weights()[ii];
290 return (*
this)[l] < (*this)[r];
304 template <
class DataSource,
class CountArray>
307 DataSource
const & labels_;
308 CountArray & counts_;
327 counts_[labels_[l]] +=1;
342 double operator[](
size_t)
const
362 template<
class Array,
class Array2>
364 Array2
const & weights,
365 double total = 1.0)
const
367 return impurity(hist, weights, total);
372 template<
class Array>
373 double operator()(Array
const & hist,
double total = 1.0)
const
380 template<
class Array>
381 static double impurity(Array
const & hist,
double total)
383 return impurity(hist, detail::ConstArr<1>(), total);
388 template<
class Array,
class Array2>
390 Array2
const & weights,
394 int class_count = hist.size();
395 double entropy = 0.0;
398 double p0 = (hist[0]/total);
399 double p1 = (hist[1]/total);
404 for(
int ii = 0; ii < class_count; ++ii)
406 double w = weights[ii];
407 double pii = hist[ii]/total;
411 entropy = total * entropy;
424 template<
class Array,
class Array2>
426 Array2
const & weights,
427 double total = 1.0)
const
429 return impurity(hist, weights, total);
434 template<
class Array>
435 double operator()(Array
const & hist,
double total = 1.0)
const
442 template<
class Array>
443 static double impurity(Array
const & hist,
double total)
445 return impurity(hist, detail::ConstArr<1>(), total);
450 template<
class Array,
class Array2>
452 Array2
const & weights,
456 int class_count = hist.size();
460 double w = weights[0] * weights[1];
461 gini = w * (hist[0] * hist[1] / total);
465 for(
int ii = 0; ii < class_count; ++ii)
467 double w = weights[ii];
468 gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
476 template <
class DataSource,
class Impurity= GiniCriterion>
480 DataSource
const & labels_;
481 ArrayVector<double> counts_;
482 ArrayVector<double>
const class_weights_;
483 double total_counts_;
489 ImpurityLoss(DataSource
const & labels,
490 ProblemSpec<T>
const & ext_)
492 counts_(ext_.class_count_, 0.0),
493 class_weights_(ext_.class_weights_),
503 template<
class Counts>
504 double increment_histogram(Counts
const & counts)
506 std::transform(counts.begin(), counts.end(),
507 counts_.begin(), counts_.begin(),
508 std::plus<double>());
509 total_counts_ = std::accumulate( counts_.begin(),
512 return impurity_(counts_, class_weights_, total_counts_);
515 template<
class Counts>
516 double decrement_histogram(Counts
const & counts)
518 std::transform(counts.begin(), counts.end(),
519 counts_.begin(), counts_.begin(),
520 std::minus<double>());
521 total_counts_ = std::accumulate( counts_.begin(),
524 return impurity_(counts_, class_weights_, total_counts_);
528 double increment(Iter begin, Iter end)
530 for(Iter iter = begin; iter != end; ++iter)
532 counts_[labels_(*iter, 0)] +=1.0;
535 return impurity_(counts_, class_weights_, total_counts_);
539 double decrement(Iter
const & begin, Iter
const & end)
541 for(Iter iter = begin; iter != end; ++iter)
543 counts_[labels_(*iter,0)] -=1.0;
546 return impurity_(counts_, class_weights_, total_counts_);
549 template<
class Iter,
class Resp_t>
550 double init (Iter , Iter , Resp_t resp)
553 std::copy(resp.begin(), resp.end(), counts_.begin());
554 total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
555 return impurity_(counts_,class_weights_, total_counts_);
558 ArrayVector<double>
const & response()
566 template <
class DataSource>
567 class RegressionForestCounter
571 DataSource
const & labels_;
572 ArrayVector <double> mean_;
573 ArrayVector <double> variance_;
574 ArrayVector <double> tmp_;
579 RegressionForestCounter(DataSource
const & labels,
580 ProblemSpec<T>
const & ext_)
583 mean_(ext_.response_size_, 0.0),
584 variance_(ext_.response_size_, 0.0),
585 tmp_(ext_.response_size_),
590 double increment (Iter begin, Iter end)
592 for(Iter iter = begin; iter != end; ++iter)
595 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
596 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
597 double f = 1.0 / count_,
599 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
600 mean_[ii] += f*tmp_[ii];
601 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
602 variance_[ii] += f1*
sq(tmp_[ii]);
604 double res = std::accumulate(variance_.begin(),
607 std::plus<double>());
613 double decrement (Iter begin, Iter end)
615 for(Iter iter = begin; iter != end; ++iter)
624 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
627 for(Iter iter = begin; iter != end; ++iter)
629 mean_[ii] += labels_(*iter, ii);
633 for(Iter iter = begin; iter != end; ++iter)
635 variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]);
638 double res = std::accumulate(variance_.begin(),
641 std::plus<double>());
647 template<
class Iter,
class Resp_t>
648 double init (Iter begin, Iter end, Resp_t )
651 return this->increment(begin, end);
656 ArrayVector<double>
const & response()
670 template <
class DataSource>
671 class RegressionForestCounter2
675 DataSource
const & labels_;
676 ArrayVector <double> mean_;
677 ArrayVector <double> variance_;
678 ArrayVector <double> tmp_;
682 RegressionForestCounter2(DataSource
const & labels,
683 ProblemSpec<T>
const & ext_)
686 mean_(ext_.response_size_, 0.0),
687 variance_(ext_.response_size_, 0.0),
688 tmp_(ext_.response_size_),
693 double increment (Iter begin, Iter end)
695 for(Iter iter = begin; iter != end; ++iter)
698 for(
int ii = 0; ii < mean_.size(); ++ii)
699 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
700 double f = 1.0 / count_,
702 for(
int ii = 0; ii < mean_.size(); ++ii)
703 mean_[ii] += f*tmp_[ii];
704 for(
int ii = 0; ii < mean_.size(); ++ii)
705 variance_[ii] += f1*
sq(tmp_[ii]);
707 double res = std::accumulate(variance_.begin(),
711 /((count_ == 1)? 1:(count_ -1));
717 double decrement (Iter begin, Iter end)
719 for(Iter iter = begin; iter != end; ++iter)
721 double f = 1.0 / count_,
723 for(
int ii = 0; ii < mean_.size(); ++ii)
724 mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f);
725 for(
int ii = 0; ii < mean_.size(); ++ii)
726 variance_[ii] -= f1*
sq(labels_(*iter,ii) - mean_[ii]);
729 double res = std::accumulate(variance_.begin(),
733 /((count_ == 1)? 1:(count_ -1));
783 template<
class Iter,
class Resp_t>
784 double init (Iter begin, Iter end, Resp_t resp)
787 return this->increment(begin, end, resp);
791 ArrayVector<double>
const & response()
804 template<
class Tag,
class Datatyp>
810 template<
class Datatype>
811 struct LossTraits<GiniCriterion, Datatype>
813 typedef ImpurityLoss<Datatype, GiniCriterion> type;
816 template<
class Datatype>
817 struct LossTraits<EntropyCriterion, Datatype>
819 typedef ImpurityLoss<Datatype, EntropyCriterion> type;
822 template<
class Datatype>
823 struct LossTraits<LSQLoss, Datatype>
825 typedef RegressionForestCounter<Datatype> type;
830 template<
class LineSearchLossTag>
837 std::ptrdiff_t min_index_;
838 double min_threshold_;
847 class_weights_(ext.class_weights_),
850 bestCurrentCounts[0].resize(ext.class_count_);
851 bestCurrentCounts[1].resize(ext.class_count_);
856 class_weights_ = ext.class_weights_;
858 bestCurrentCounts[0].resize(ext.class_count_);
859 bestCurrentCounts[1].resize(ext.class_count_);
888 template<
class DataSourceF_t,
893 DataSource_t
const & labels,
896 Array
const & region_response)
898 std::sort(begin, end,
901 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
902 LineSearchLoss left(labels, ext_param_);
903 LineSearchLoss right(labels, ext_param_);
907 min_gini_ = right.init(begin, end, region_response);
908 min_threshold_ = *begin;
910 DimensionNotEqual<DataSourceF_t> comp(column, 0);
913 I_Iter next = std::adjacent_find(iter, end, comp);
917 double lr = right.decrement(iter, next + 1);
918 double ll = left.increment(iter , next + 1);
919 double loss = lr +ll;
921 #ifdef CLASSIFIER_TEST
924 if(loss < min_gini_ )
927 bestCurrentCounts[0] = left.response();
928 bestCurrentCounts[1] = right.response();
929 #ifdef CLASSIFIER_TEST
930 min_gini_ = loss < min_gini_? loss : min_gini_;
934 min_index_ = next - begin +1 ;
935 min_threshold_ = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
938 next = std::adjacent_find(iter, end, comp);
945 template<
class DataSource_t,
class Iter,
class Array>
946 double loss_of_region(DataSource_t
const & labels,
949 Array
const & region_response)
const
952 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
953 LineSearchLoss region_loss(labels, ext_param_);
955 region_loss.init(begin, end, region_response);
965 template<
class Region,
class LabelT>
966 static void exec(Region & , LabelT & )
971 struct Correction<ClassificationTag>
973 template<
class Region,
class LabelT>
974 static void exec(Region & region, LabelT & labels)
976 if(std::accumulate(region.classCounts().begin(),
977 region.classCounts().end(), 0.0) != region.size())
979 RandomForestClassCounter< LabelT,
980 ArrayVector<double> >
981 counter(labels, region.classCounts());
982 std::for_each( region.begin(), region.end(), counter);
983 region.classCountsIsValid =
true;
992 template<
class ColumnDecisionFunctor,
class Tag = ClassificationTag>
1001 ColumnDecisionFunctor bgfunc;
1003 double region_gini_;
1010 double minGini()
const
1012 return min_gini_[bestSplitIndex];
1014 int bestSplitColumn()
const
1016 return splitColumns[bestSplitIndex];
1018 double bestSplitThreshold()
const
1020 return min_thresholds_[bestSplitIndex];
1027 bgfunc.set_external_parameters( SB::ext_param_);
1028 int featureCount_ = SB::ext_param_.column_count_;
1029 splitColumns.resize(featureCount_);
1030 for(
int k=0; k<featureCount_; ++k)
1031 splitColumns[k] = k;
1032 min_gini_.resize(featureCount_);
1033 min_indices_.resize(featureCount_);
1034 min_thresholds_.resize(featureCount_);
1038 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
1046 typedef typename Region::IndexIterator IndexIterator;
1047 if(region.size() == 0)
1049 std::cerr <<
"SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
1050 "continuing learning process....";
1053 detail::Correction<Tag>::exec(region, labels);
1057 region_gini_ = bgfunc.loss_of_region(labels,
1060 region.classCounts());
1061 if(region_gini_ <= SB::ext_param_.precision_)
1065 for(
int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
1066 std::swap(splitColumns[ii],
1067 splitColumns[ii+ randint(features.
shape(1) - ii)]);
1071 double current_min_gini = region_gini_;
1072 int num2try = features.
shape(1);
1073 for(
int k=0; k<num2try; ++k)
1078 region.
begin(), region.end(),
1079 region.classCounts());
1080 min_gini_[k] = bgfunc.min_gini_;
1081 min_indices_[k] = bgfunc.min_index_;
1082 min_thresholds_[k] = bgfunc.min_threshold_;
1083 #ifdef CLASSIFIER_TEST
1084 if( bgfunc.min_gini_ < current_min_gini
1087 if(bgfunc.min_gini_ < current_min_gini)
1090 current_min_gini = bgfunc.min_gini_;
1091 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
1092 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
1093 childRegions[0].classCountsIsValid =
true;
1094 childRegions[1].classCountsIsValid =
true;
1097 num2try = SB::ext_param_.actual_mtry_;
1108 Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
1110 node.threshold() = min_thresholds_[bestSplitIndex];
1111 node.column() = splitColumns[bestSplitIndex];
1115 sorter(features, node.column(), node.threshold());
1116 IndexIterator bestSplit =
1117 std::partition(region.begin(), region.end(), sorter);
1119 childRegions[0].setRange( region.begin() , bestSplit );
1120 childRegions[0].rule = region.rule;
1121 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
1122 childRegions[1].setRange( bestSplit , region.end() );
1123 childRegions[1].rule = region.rule;
1124 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
1126 return i_ThresholdNode;
1171 std::ptrdiff_t min_index_;
1172 double min_threshold_;
1181 class_weights_(ext.class_weights_),
1184 bestCurrentCounts[0].resize(ext.class_count_);
1185 bestCurrentCounts[1].resize(ext.class_count_);
1191 class_weights_ = ext.class_weights_;
1193 bestCurrentCounts[0].resize(ext.class_count_);
1194 bestCurrentCounts[1].resize(ext.class_count_);
1197 template<
class DataSourceF_t,
1201 void operator()(DataSourceF_t
const & column,
1202 DataSource_t
const & labels,
1205 Array
const & region_response)
1207 std::sort(begin, end,
1210 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1211 LineSearchLoss left(labels, ext_param_);
1212 LineSearchLoss right(labels, ext_param_);
1213 right.init(begin, end, region_response);
1215 min_gini_ = NumericTraits<double>::max();
1216 min_index_ =
floor(
double(end - begin)/2.0);
1217 min_threshold_ = column[*(begin + min_index_)];
1219 sorter(column, 0, min_threshold_);
1220 I_Iter part = std::partition(begin, end, sorter);
1221 DimensionNotEqual<DataSourceF_t> comp(column, 0);
1224 part= std::adjacent_find(part, end, comp)+1;
1233 min_threshold_ = column[*part];
1235 min_gini_ = right.decrement(begin, part)
1236 + left.increment(begin , part);
1238 bestCurrentCounts[0] = left.response();
1239 bestCurrentCounts[1] = right.response();
1241 min_index_ = part - begin;
1244 template<
class DataSource_t,
class Iter,
class Array>
1245 double loss_of_region(DataSource_t
const & labels,
1248 Array
const & region_response)
const
1251 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1252 LineSearchLoss region_loss(labels, ext_param_);
1254 region_loss.init(begin, end, region_response);
1272 std::ptrdiff_t min_index_;
1273 double min_threshold_;
1284 class_weights_(ext.class_weights_),
1288 bestCurrentCounts[0].resize(ext.class_count_);
1289 bestCurrentCounts[1].resize(ext.class_count_);
1295 class_weights_(ext.class_weights_),
1299 bestCurrentCounts[0].resize(ext.class_count_);
1300 bestCurrentCounts[1].resize(ext.class_count_);
1306 class_weights_ = ext.class_weights_;
1308 bestCurrentCounts[0].resize(ext.class_count_);
1309 bestCurrentCounts[1].resize(ext.class_count_);
1312 template<
class DataSourceF_t,
1316 void operator()(DataSourceF_t
const & column,
1317 DataSource_t
const & labels,
1320 Array
const & region_response)
1322 std::sort(begin, end,
1325 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1326 LineSearchLoss left(labels, ext_param_);
1327 LineSearchLoss right(labels, ext_param_);
1328 right.init(begin, end, region_response);
1331 min_gini_ = NumericTraits<double>::max();
1332 int tmp_pt = random.
uniformInt(std::distance(begin, end));
1333 min_index_ = tmp_pt;
1334 min_threshold_ = column[*(begin + min_index_)];
1336 sorter(column, 0, min_threshold_);
1337 I_Iter part = std::partition(begin, end, sorter);
1338 DimensionNotEqual<DataSourceF_t> comp(column, 0);
1341 part= std::adjacent_find(part, end, comp)+1;
1350 min_threshold_ = column[*part];
1352 min_gini_ = right.decrement(begin, part)
1353 + left.increment(begin , part);
1355 bestCurrentCounts[0] = left.response();
1356 bestCurrentCounts[1] = right.response();
1358 min_index_ = part - begin;
1361 template<
class DataSource_t,
class Iter,
class Array>
1362 double loss_of_region(DataSource_t
const & labels,
1365 Array
const & region_response)
const
1368 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
1369 LineSearchLoss region_loss(labels, ext_param_);
1371 region_loss.init(begin, end, region_response);
1382 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
UInt32 uniformInt() const
Definition: random.hxx:464
static double impurity(Array const &hist, double total)
Definition: rf_split.hxx:443
Definition: rf_region.hxx:57
Definition: rf_nodeproxy.hxx:626
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
Definition: rf_split.hxx:201
const difference_type & shape() const
Definition: multi_array.hxx:1648
Definition: rf_split.hxx:993
Definition: rf_split.hxx:305
const_iterator begin() const
Definition: array_vector.hxx:223
void set_external_parameters(ProblemSpec< T > const &in)
Definition: rf_split.hxx:112
problem specification class for the random forest.
Definition: rf_common.hxx:538
iterator begin()
Definition: multi_array.hxx:1921
int findBestSplit(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region, ArrayVector< Region >, Random)
Definition: rf_split.hxx:150
Definition: rf_split.hxx:356
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
double operator()(Array const &hist, Array2 const &weights, double total=1.0) const
Definition: rf_split.hxx:425
static double impurity(Array const &hist, Array2 const &weights, double total)
Definition: rf_split.hxx:389
Definition: rf_nodeproxy.hxx:87
double operator()(Array const &hist, double total=1.0) const
Definition: rf_split.hxx:435
NumericTraits< T >::Promote sq(T t)
The square function.
Definition: mathutil.hxx:382
Definition: rf_split.hxx:831
double operator()(Array const &hist, double total=1.0) const
Definition: rf_split.hxx:373
TinyVector< MultiArrayIndex, N > type
Definition: multi_shape.hxx:272
void reset()
Definition: rf_split.hxx:137
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality.
Definition: mathutil.hxx:1638
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
void operator()(DataSourceF_t const &column, DataSource_t const &labels, I_Iter &begin, I_Iter &end, Array const &region_response)
Definition: rf_split.hxx:892
static double impurity(Array const &hist, double total)
Definition: rf_split.hxx:381
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
static double impurity(Array const &hist, Array2 const &weights, double total)
Definition: rf_split.hxx:451
Definition: rf_split.hxx:1264
size_type size() const
Definition: array_vector.hxx:358
int floor(FixedPoint< IntBits, FracBits > v)
rounding down.
Definition: fixedpoint.hxx:667
double & weights()
Definition: rf_nodeproxy.hxx:115
Definition: rf_split.hxx:92
double operator()(Array const &hist, Array2 const &weights, double total=1.0) const
Definition: rf_split.hxx:363
int makeTerminalNode(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region &region, Random)
Definition: rf_split.hxx:168
Definition: rf_split.hxx:418