12 #ifndef VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
13 #define VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
15 #include "../sampling.hxx"
16 #include "rf_split.hxx"
17 #include "rf_nodeproxy.hxx"
18 #include "../regression.hxx"
// Ad-hoc debugging helpers that write to std::cout.
//   outm(expr)  -> "expr: <value>" + std::endl (newline, stream flushed)
//   outm2(expr) -> "expr: <value>, "           (no newline; several per line)
// Each expansion already ends in ';'. A braceless use such as
// `if (c) outm(x); else ...` therefore does not compile; brace such branches.
#define outm(v) (std::cout << #v ": " << (v) << std::endl);
#define outm2(v) (std::cout << #v ": " << (v) << ", ");
// RidgeSplit: oblique-split functor for the random forest. For each node it
// projects the samples onto a hyperplane over a random subset of
// SB::ext_param_.actual_mtry_ feature columns, with weights taken from ridge
// regressions (see findBestSplit). ColumnDecisionFunctor then chooses the
// threshold along that projection.
// NOTE(review): this listing omits several lines of the class: the opening
// brace, access specifiers, some members (e.g. region_gini_ and
// bestSplitIndex, both used but not declared here) and the constructor
// signature. Confirm against the full header.
70 template<
class ColumnDecisionFunctor,
class Tag = ClassificationTag>
71 class RidgeSplit:
public SplitBase<Tag>
// Base-class shorthand. findBestSplit uses SB::ext_param_, SB::t_data and SB::p_data.
76 typedef SplitBase<Tag> SB;
// Candidate feature column indices. findBestSplit swaps a random subset into
// the first actual_mtry_ entries and records them as the hyperplane columns.
78 ArrayVector<Int32> splitColumns;
// Threshold-search functor applied to the 1-D projections. Its
// min_threshold_, min_index_ and bestCurrentCounts are read after each call.
79 ColumnDecisionFunctor bgfunc;
// Per-feature result buffers, resized to the feature count in
// set_external_parameters().
// NOTE(review): the findBestSplit code visible in this listing never writes them.
82 ArrayVector<double> min_gini_;
83 ArrayVector<std::ptrdiff_t> min_indices_;
84 ArrayVector<double> min_thresholds_;
// true:  sampled features are centered AND divided by a per-column spread
//        estimate before the ridge fit.
// false: features are only centered (the spread stays 1.0).
89 bool m_bDoScalingInTraining;
// true: during the lambda search, each candidate hyperplane's threshold is
// chosen by bgfunc on the OOB projections. Otherwise (branch partly elided)
// the threshold appears to be derived from the feature means.
90 bool m_bDoBestLambdaBasedOnGini;
// Constructor initializer list (signature line elided): both options
// default to true.
93 :m_bDoScalingInTraining(true),
94 m_bDoBestLambdaBasedOnGini(true)
98 double minGini()
const
100 return min_gini_[bestSplitIndex];
103 int bestSplitColumn()
const
105 return splitColumns[bestSplitIndex];
108 bool& doScalingInTraining()
109 {
return m_bDoScalingInTraining; }
111 bool& doBestLambdaBasedOnGini()
112 {
return m_bDoBestLambdaBasedOnGini; }
115 void set_external_parameters(ProblemSpec<T>
const & in)
118 bgfunc.set_external_parameters(in);
119 int featureCount_ = in.column_count_;
120 splitColumns.resize(featureCount_);
121 for(
int k=0; k<featureCount_; ++k)
123 min_gini_.resize(featureCount_);
124 min_indices_.resize(featureCount_);
125 min_thresholds_.resize(featureCount_);
// findBestSplit: builds one oblique (hyperplane) split of `region` and
// distributes its in-bag and OOB samples over childRegions[0] and
// childRegions[1].
//
// Parameters (partly elided in this listing):
//   features         - samples x columns, indexed features(sample, column)
//   multiClassLabels - one label per sample, in column 0
//   region           - in-bag sample indices plus an OOB range
//   childRegions     - output; two entries are used
//   a random source  - called as randint(k)
// Returns: i_HyperplaneNode on the path shown. The early-exit path (body
// elided) presumably returns a terminal node instead.
//
// NOTE(review): many lines of this function are missing from this listing
// (parameter list, braces, several declarations, the ridge-regression call,
// the early-return body). The comments below describe only the lines that
// are present.
129 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
131 MultiArrayView<2, T2, C2> multiClassLabels,
133 ArrayVector<Region>& childRegions,
138 typedef typename MultiArrayView <2, T, C>::difference_type fShape;
// Recount the class histogram over the region's (in-bag) samples if the
// cached counts do not add up to region.size().
144 if(std::accumulate(region.classCounts().begin(),
145 region.classCounts().end(), 0) != region.size())
147 RandomForestClassCounter< MultiArrayView<2,T2, C2>,
148 ArrayVector<double> >
149 counter(multiClassLabels, region.classCounts());
150 std::for_each( region.begin(), region.end(), counter);
151 region.classCountsIsValid =
true;
// Early exit (body elided) when any of these holds:
//   - the region is pure,
//   - it has fewer samples than actual_mtry_,
//   - it has fewer than 2 OOB samples (the OOB samples are used below to
//     rank the lambdas).
158 if(region_gini_ == 0 || region.size() < SB::ext_param_.actual_mtry_ || region.oob_size() < 2)
// Partial Fisher-Yates shuffle. Afterwards splitColumns[0..actual_mtry_)
// holds a random subset of the entries. This assumes randint(k) is uniform
// on [0, k) and that splitColumns holds each column index exactly once.
162 for(
int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
163 std::swap(splitColumns[ii],
164 splitColumns[ii+ randint(features.shape(1) - ii)]);
// Binary working labels. In one branch (condition elided; nNumClasses is
// counted for it) samples of nMaxClass become 1 and all others 0. nMaxClass
// is presumably the most frequent class in the region (assignment elided).
// In the other branch the labels are copied unchanged.
167 MultiArray<2, T2> labels(lShape(multiClassLabels.shape(0),1));
170 for(
int n=0; n<static_cast<int>(region.classCounts().size()); n++)
171 nNumClasses+=((region.classCounts()[n]>0) ? 1:0);
177 int nMaxClassCounts=0;
178 for(
int n=0; n<static_cast<int>(region.classCounts().size()); n++)
182 if(region.classCounts()[n]>nMaxClassCounts)
184 nMaxClassCounts=region.classCounts()[n];
190 for(
int n=0; n<multiClassLabels.shape(0); n++)
191 labels(n,0)=((multiClassLabels(n,0)==nMaxClass) ? 1:0);
194 labels=multiClassLabels;
// Standardize each sampled column over the in-bag samples:
//   xtrain(n, m) = (x - mean_m) / spread_m
// spread_m stays 1.0 unless m_bDoScalingInTraining is set. cVector is
// presumably bound to column splitColumns[m] on an elided line.
228 MultiArrayView<2, T, C> cVector;
229 MultiArray<2, T> xtrain(fShape(region.size(),SB::ext_param_.actual_mtry_));
231 MultiArray<2, double> regrLabels(dShape(region.size(),1));
234 MultiArray<2, double> meanMatrix(dShape(SB::ext_param_.actual_mtry_,1));
235 MultiArray<2, double> stdMatrix(dShape(SB::ext_param_.actual_mtry_,1));
236 for(
int m=0; m<SB::ext_param_.actual_mtry_; m++)
241 double dCurrFeatureColumnMean=0.0;
242 double dCurrFeatureColumnStd=1.0;
245 for(
int n=0; n<region.size(); n++)
246 dCurrFeatureColumnMean+=cVector[region[n]];
247 dCurrFeatureColumnMean/=region.size();
// NOTE(review): the spread accumulator starts at 1.0 (the no-scaling
// default) and the squared deviations are added on top of it. With scaling
// on, the result is therefore sqrt((1 + SS) / (n - 1)) instead of the sample
// standard deviation sqrt(SS / (n - 1)). It should be reset to 0.0 before
// the sum. A fix must then guard constant columns (SS == 0) against
// division by zero in xtrain.
249 if(m_bDoScalingInTraining)
251 for(
int n=0; n<region.size(); n++)
253 dCurrFeatureColumnStd+=
254 (cVector[region[n]]-dCurrFeatureColumnMean)*(cVector[region[n]]-dCurrFeatureColumnMean);
257 dCurrFeatureColumnStd=
sqrt(dCurrFeatureColumnStd/(region.size()-1));
260 stdMatrix(m,0)=dCurrFeatureColumnStd;
262 meanMatrix(m,0)=dCurrFeatureColumnMean;
266 for(
int n=0; n<region.size(); n++)
267 xtrain(n,m)=(cVector[region[n]]-dCurrFeatureColumnMean)/dCurrFeatureColumnStd;
// Regression targets: -1 for working label 0, +1 otherwise.
272 for(
int n=0; n<region.size(); n++)
277 regrLabels(n,0)=((labels[region[n]]==0) ? -1:1);
// Ridge penalties 10^-5 .. 10^5: 11 values, and every literal 11 below must
// match that count. The fit itself, one column of regrCoef per lambda,
// happens on elided lines.
280 MultiArray<2, double> dLambdas(dShape(11,1));
282 for(
int nLambda=-5; nLambda<=5; nLambda++)
283 dLambdas[nCounter++]=pow(10.0,nLambda);
285 MultiArray<2, double> regrCoef(dShape(SB::ext_param_.actual_mtry_,11));
// Lambda search: score each lambda by the number of OOB samples its
// hyperplane classifies correctly, and keep the best index (ties keep the
// earlier one).
288 double dMaxRidgeSum=NumericTraits<double>::min();
// NOTE(review): no reset of dCurrRidgeSum is visible before the accumulation
// below. Confirm it is zeroed for every lambda.
289 double dCurrRidgeSum;
290 int nMaxRidgeSumAtLambdaInd=0;
292 for(
int nLambdaInd=0; nLambdaInd<11; nLambdaInd++)
// Projection of the OOB samples onto the candidate hyperplane.
// NOTE(review): regrCoef is fitted on the standardized xtrain but applied
// here to raw features. The mean term added to dCurrIntercept below accounts
// for the centering, but nothing accounts for the division by stdMatrix, so
// the lambda ranking is distorted when m_bDoScalingInTraining is true.
300 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
302 for(
int n=0; n<region.oob_size(); n++)
304 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
305 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
307 dDistanceFromHyperplane(region.oob_begin()[n],0)+=
308 features(region.oob_begin()[n],splitColumns[m])*regrCoef(m,nLambdaInd);
// Threshold for this lambda. When m_bDoBestLambdaBasedOnGini is set it comes
// from bgfunc on the OOB projections. The sum of mean*coef below (enclosing
// branch elided) is the zero crossing of the centered model.
// NOTE(review): bgfunc receives the OOB range together with
// region.classCounts(), which counts the in-bag samples. Confirm the functor
// does not rely on the counts matching the range.
312 double dCurrIntercept=0.0;
313 if(m_bDoBestLambdaBasedOnGini)
316 bgfunc(dDistanceFromHyperplane,
318 region.oob_begin(), region.oob_end(),
319 region.classCounts());
320 dCurrIntercept=bgfunc.min_threshold_;
324 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
325 dCurrIntercept+=meanMatrix(m,0)*regrCoef(m,nLambdaInd);
// Count the correctly classified OOB samples. The prediction is 1 iff the
// projection is >= the threshold.
328 for(
int n=0; n<region.oob_size(); n++)
331 int nClassPrediction=((dDistanceFromHyperplane(region.oob_begin()[n],0) >=dCurrIntercept) ? 1:0);
332 dCurrRidgeSum+=((nClassPrediction == labels(region.oob_begin()[n],0)) ? 1:0);
334 if(dCurrRidgeSum>dMaxRidgeSum)
336 dMaxRidgeSum=dCurrRidgeSum;
337 nMaxRidgeSumAtLambdaInd=nLambdaInd;
// Build the hyperplane node from the best lambda's coefficients.
// NOTE(review): mapping coefficients of standardized features back to raw
// feature space requires coef / std (since x_std = (x - mean) / std), but
// this multiplies by stdMatrix. That is harmless only when std == 1 (no
// scaling). In addition, the norm is taken of regrCoef rather than of the
// rescaled dCoeffVector, so node.weights() is not unit length when std != 1.
343 Node<i_HyperplaneNode> node(SB::ext_param_.actual_mtry_, SB::t_data, SB::p_data);
347 MultiArray<2, double> dCoeffVector(dShape(SB::ext_param_.actual_mtry_,1));
348 for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)
349 dCoeffVector(n,0)=regrCoef(n,nMaxRidgeSumAtLambdaInd)*stdMatrix(n,0);
352 double dVnorm=
columnVector(regrCoef,nMaxRidgeSumAtLambdaInd).norm();
354 for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)
355 node.weights()[n]=dCoeffVector(n,0)/dVnorm;
// column_data layout: [0] = number of columns, [1..actual_mtry_] = the
// feature column indices the weights refer to.
359 node.column_data()[0]=SB::ext_param_.actual_mtry_;
360 for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)
361 node.column_data()[n+1]=splitColumns[n];
// Project the in-bag and OOB samples onto the final hyperplane.
// NOTE(review): both loops below index features(sample, m) instead of
// features(sample, splitColumns[m]). The node stores splitColumns[m] as its
// columns, and the lambda search above projects on splitColumns[m], so these
// projections use the wrong columns unless splitColumns[m] == m.
367 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));
369 for(
int n=0; n<region.size(); n++)
371 dDistanceFromHyperplane(region[n],0)=0.0;
372 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
374 dDistanceFromHyperplane(region[n],0)+=
375 features(region[n],m)*node.weights()[m];
378 for(
int n=0; n<region.oob_size(); n++)
380 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;
381 for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)
383 dDistanceFromHyperplane(region.oob_begin()[n],0)+=
384 features(region.oob_begin()[n],m)*node.weights()[m];
// Choose the node threshold on the in-bag projections. The children below
// take contiguous ranges of region split at bgfunc.min_index_, so bgfunc
// presumably reorders region's samples by projection.
389 bgfunc(dDistanceFromHyperplane,
391 region.begin(), region.end(),
392 region.classCounts());
399 node.intercept() = bgfunc.min_threshold_;
// Children: class counts come from bgfunc; in-bag ranges are split at min_index_.
402 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
403 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
404 childRegions[0].classCountsIsValid =
true;
405 childRegions[1].classCountsIsValid =
true;
// NOTE(review): both children receive the identical rule entry (1, 1.0).
// Confirm this is intended.
408 childRegions[0].setRange( region.begin() , region.begin() + bgfunc.min_index_ );
409 childRegions[0].rule = region.rule;
410 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
411 childRegions[1].setRange( region.begin() + bgfunc.min_index_ , region.end() );
412 childRegions[1].rule = region.rule;
413 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
// OOB samples: sort by projection, then split at the first sample whose
// projection is >= node.intercept(). The loop's exit and nOOBindx's
// declaration are on elided lines.
418 std::sort(region.oob_begin(), region.oob_end(),
419 SortSamplesByDimensions< MultiArray<2, double> > (dDistanceFromHyperplane, 0));
423 for(nOOBindx=0; nOOBindx<region.oob_size(); nOOBindx++)
425 if(dDistanceFromHyperplane(region.oob_begin()[nOOBindx],0)>=node.intercept())
429 childRegions[0].set_oob_range( region.oob_begin() , region.oob_begin() + nOOBindx );
430 childRegions[1].set_oob_range( region.oob_begin() + nOOBindx , region.oob_end() );
436 return i_HyperplaneNode;
446 #endif // VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
static double impurity(Array const &hist, double total)
Definition: rf_split.hxx:443
RidgeSplit< BestGiniOfColumn< GiniCriterion > > GiniRidgeSplit
Definition: rf_ridge_split.hxx:442
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArrayShape< actual_dimension >::type difference_type
Definition: multi_array.hxx:739
void set_external_parameters(ProblemSpec< T > const &in)
Definition: rf_split.hxx:112
int findBestSplit(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region, ArrayVector< Region >, Random)
Definition: rf_split.hxx:150
bool ridgeRegressionSeries(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > &x, Array const &lambda)
Definition: regression.hxx:304
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality.
Definition: mathutil.hxx:1638
int makeTerminalNode(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region &region, Random)
Definition: rf_split.hxx:168
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616