35#ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
36#define VIGRA_RANDOM_FOREST_SPLIT_HXX
42#include "../mathutil.hxx"
43#include "../array_vector.hxx"
44#include "../sized_int.hxx"
45#include "../matrix.hxx"
46#include "../random.hxx"
47#include "../functorexpression.hxx"
48#include "rf_nodeproxy.hxx"
50#include "rf_region.hxx"
59class CompileTimeError;
69 static void exec(Iter , Iter )
74 class Normalise<ClassificationTag>
78 static void exec (Iter begin, Iter end)
80 double bla = std::accumulate(begin, end, 0.0);
81 for(
int ii = 0; ii < end - begin; ++ii)
82 begin[ii] = begin[ii]/bla ;
115 t_data.push_back(
in.column_count_);
116 t_data.push_back(
in.class_count_);
124 int classCount()
const
126 return int(t_data[1]);
129 int featureCount()
const
131 return int(t_data[0]);
149 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
167 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
175 if(ext_param_.class_weights_.
size() !=
region.classCounts().
size())
185 ext_param_.class_weights_.
begin(),
186 ret.prob_begin(), std::multiplies<double>());
188 detail::Normalise<RF_Tag>::exec(
ret.prob_begin(),
ret.prob_end());
192 return e_ConstProbNode;
200template<
class DataMatrix>
220 void setThreshold(
double value)
227 return data_(
l, sortColumn_) < data_(r, sortColumn_);
231 return data_(
l, sortColumn_) < thresVal_;
235template<
class DataMatrix>
236class DimensionNotEqual
251 sortColumn_ = sortColumn;
256 return data_(l, sortColumn_) != data_(r, sortColumn_);
260template<
class DataMatrix>
261class SortSamplesByHyperplane
263 DataMatrix
const & data_;
264 Node<i_HyperplaneNode>
const & node_;
268 SortSamplesByHyperplane(DataMatrix
const & data,
269 Node<i_HyperplaneNode>
const & node)
279 double result_l = -1 * node_.intercept();
280 for(
int ii = 0; ii < node_.columns_size(); ++ii)
282 result_l +=
rowVector(data_, l)[node_.columns_begin()[ii]]
283 * node_.weights()[ii];
290 return (*
this)[l] < (*this)[r];
304template <
class DataSource,
class CountArray>
327 counts_[labels_[
l]] +=1;
342 double operator[](
size_t)
const
362 template<
class Array,
class Array2>
365 double total = 1.0)
const
372 template<
class Array>
380 template<
class Array>
388 template<
class Array,
class Array2>
400 entropy = 0 - weights[0]*
p0*std::log(
p0) - weights[1]*p1*std::log(p1);
404 for(
int ii = 0;
ii < class_count; ++
ii)
406 double w = weights[
ii];
424 template<
class Array,
class Array2>
427 double total = 1.0)
const
434 template<
class Array>
442 template<
class Array>
450 template<
class Array,
class Array2>
460 double w = weights[0] * weights[1];
465 for(
int ii = 0;
ii < class_count; ++
ii)
467 double w = weights[
ii];
476template <
class DataSource,
class Impurity= GiniCriterion>
480 DataSource
const & labels_;
481 ArrayVector<double> counts_;
482 ArrayVector<double>
const class_weights_;
483 double total_counts_;
489 ImpurityLoss(DataSource
const & labels,
490 ProblemSpec<T>
const & ext_)
492 counts_(ext_.class_count_, 0.0),
493 class_weights_(ext_.class_weights_),
503 template<
class Counts>
504 double increment_histogram(Counts
const & counts)
506 std::transform(counts.begin(), counts.end(),
507 counts_.begin(), counts_.begin(),
508 std::plus<double>());
509 total_counts_ = std::accumulate( counts_.begin(),
512 return impurity_(counts_, class_weights_, total_counts_);
515 template<
class Counts>
516 double decrement_histogram(Counts
const & counts)
518 std::transform(counts.begin(), counts.end(),
519 counts_.begin(), counts_.begin(),
520 std::minus<double>());
521 total_counts_ = std::accumulate( counts_.begin(),
524 return impurity_(counts_, class_weights_, total_counts_);
528 double increment(Iter begin, Iter end)
530 for(Iter iter = begin; iter != end; ++iter)
532 counts_[labels_(*iter, 0)] +=1.0;
535 return impurity_(counts_, class_weights_, total_counts_);
539 double decrement(Iter
const & begin, Iter
const & end)
541 for(Iter iter = begin; iter != end; ++iter)
543 counts_[labels_(*iter,0)] -=1.0;
546 return impurity_(counts_, class_weights_, total_counts_);
549 template<
class Iter,
class Resp_t>
550 double init (Iter , Iter , Resp_t resp)
553 std::copy(resp.begin(), resp.end(), counts_.begin());
554 total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
555 return impurity_(counts_,class_weights_, total_counts_);
558 ArrayVector<double>
const & response()
566 template <
class DataSource>
567 class RegressionForestCounter
570 typedef MultiArrayShape<2>::type Shp;
571 DataSource
const & labels_;
572 ArrayVector <double> mean_;
573 ArrayVector <double> variance_;
574 ArrayVector <double> tmp_;
579 RegressionForestCounter(DataSource
const & labels,
580 ProblemSpec<T>
const & ext_)
583 mean_(ext_.response_size_, 0.0),
584 variance_(ext_.response_size_, 0.0),
585 tmp_(ext_.response_size_),
590 double increment (Iter begin, Iter end)
592 for(Iter iter = begin; iter != end; ++iter)
595 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
596 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
597 double f = 1.0 / count_,
599 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
600 mean_[ii] += f*tmp_[ii];
601 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
602 variance_[ii] += f1*sq(tmp_[ii]);
604 double res = std::accumulate(variance_.begin(),
607 std::plus<double>());
613 double decrement (Iter begin, Iter end)
615 for(Iter iter = begin; iter != end; ++iter)
624 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
627 for(Iter iter = begin; iter != end; ++iter)
629 mean_[ii] += labels_(*iter, ii);
633 for(Iter iter = begin; iter != end; ++iter)
635 variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]);
638 double res = std::accumulate(variance_.begin(),
641 std::plus<double>());
647 template<
class Iter,
class Resp_t>
648 double init (Iter begin, Iter end, Resp_t )
651 return this->increment(begin, end);
656 ArrayVector<double>
const & response()
670template <
class DataSource>
671class RegressionForestCounter2
674 typedef MultiArrayShape<2>::type Shp;
675 DataSource
const & labels_;
676 ArrayVector <double> mean_;
677 ArrayVector <double> variance_;
678 ArrayVector <double> tmp_;
682 RegressionForestCounter2(DataSource
const & labels,
683 ProblemSpec<T>
const & ext_)
686 mean_(ext_.response_size_, 0.0),
687 variance_(ext_.response_size_, 0.0),
688 tmp_(ext_.response_size_),
693 double increment (Iter begin, Iter end)
695 for(Iter iter = begin; iter != end; ++iter)
698 for(
int ii = 0; ii < mean_.size(); ++ii)
699 tmp_[ii] = labels_(*iter, ii) - mean_[ii];
700 double f = 1.0 / count_,
702 for(
int ii = 0; ii < mean_.size(); ++ii)
703 mean_[ii] += f*tmp_[ii];
704 for(
int ii = 0; ii < mean_.size(); ++ii)
705 variance_[ii] += f1*sq(tmp_[ii]);
707 double res = std::accumulate(variance_.begin(),
711 /((count_ == 1)? 1:(count_ -1));
717 double decrement (Iter begin, Iter end)
719 for(Iter iter = begin; iter != end; ++iter)
721 double f = 1.0 / count_,
723 for(
int ii = 0; ii < mean_.size(); ++ii)
724 mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f);
725 for(
int ii = 0; ii < mean_.size(); ++ii)
726 variance_[ii] -= f1*sq(labels_(*iter,ii) - mean_[ii]);
729 double res = std::accumulate(variance_.begin(),
733 /((count_ == 1)? 1:(count_ -1));
783 template<
class Iter,
class Resp_t>
784 double init (Iter begin, Iter end, Resp_t resp)
787 return this->increment(begin, end, resp);
791 ArrayVector<double>
const & response()
804template<
class Tag,
class Datatyp>
810template<
class Datatype>
811struct LossTraits<GiniCriterion, Datatype>
813 typedef ImpurityLoss<Datatype, GiniCriterion> type;
816template<
class Datatype>
817struct LossTraits<EntropyCriterion, Datatype>
819 typedef ImpurityLoss<Datatype, EntropyCriterion> type;
822template<
class Datatype>
823struct LossTraits<LSQLoss, Datatype>
825 typedef RegressionForestCounter<Datatype> type;
830template<
class LineSearchLossTag>
837 std::ptrdiff_t min_index_;
838 double min_threshold_;
847 class_weights_(
ext.class_weights_),
850 bestCurrentCounts[0].resize(
ext.class_count_);
851 bestCurrentCounts[1].resize(
ext.class_count_);
856 class_weights_ =
ext.class_weights_;
858 bestCurrentCounts[0].resize(
ext.class_count_);
859 bestCurrentCounts[1].resize(
ext.class_count_);
898 std::sort(begin, end,
908 min_threshold_ = *begin;
913 I_Iter next = std::adjacent_find(iter, end,
comp);
917 double lr = right.decrement(iter, next + 1);
918 double ll = left.increment(iter , next + 1);
921#ifdef CLASSIFIER_TEST
924 if(
loss < min_gini_ )
927 bestCurrentCounts[0] = left.response();
928 bestCurrentCounts[1] = right.response();
929#ifdef CLASSIFIER_TEST
930 min_gini_ =
loss < min_gini_?
loss : min_gini_;
934 min_index_ = next - begin +1 ;
935 min_threshold_ = (
double(column(*next,0)) +
double(column(*(next +1), 0)))/2.0;
938 next = std::adjacent_find(iter, end,
comp);
945 template<
class DataSource_t,
class Iter,
class Array>
965 template<
class Region,
class LabelT>
966 static void exec(Region & , LabelT & )
971 struct Correction<ClassificationTag>
973 template<
class Region,
class LabelT>
974 static void exec(Region & region, LabelT & labels)
976 if(std::accumulate(region.classCounts().begin(),
977 region.classCounts().end(), 0.0) != region.size())
979 RandomForestClassCounter< LabelT,
980 ArrayVector<double> >
981 counter(labels, region.classCounts());
982 std::for_each( region.begin(), region.end(), counter);
983 region.classCountsIsValid =
true;
992template<
class ColumnDecisionFunctor,
class Tag = ClassificationTag>
1001 ColumnDecisionFunctor bgfunc;
1003 double region_gini_;
1010 double minGini()
const
1012 return min_gini_[bestSplitIndex];
1014 int bestSplitColumn()
const
1016 return splitColumns[bestSplitIndex];
1018 double bestSplitThreshold()
const
1020 return min_thresholds_[bestSplitIndex];
1026 SB::set_external_parameters(
in);
1027 bgfunc.set_external_parameters( SB::ext_param_);
1028 int featureCount_ = SB::ext_param_.column_count_;
1029 splitColumns.resize(featureCount_);
1030 for(
int k=0;
k<featureCount_; ++
k)
1031 splitColumns[
k] =
k;
1032 min_gini_.resize(featureCount_);
1033 min_indices_.resize(featureCount_);
1034 min_thresholds_.resize(featureCount_);
1038 template<
class T,
class C,
class T2,
class C2,
class Region,
class Random>
1046 typedef typename Region::IndexIterator IndexIterator;
1049 std::cerr <<
"SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
1050 "continuing learning process....";
1053 detail::Correction<Tag>::exec(
region, labels);
1057 region_gini_ = bgfunc.loss_of_region(labels,
1061 if(region_gini_ <= SB::ext_param_.precision_)
1065 for(
int ii = 0;
ii < SB::ext_param_.actual_mtry_; ++
ii)
1066 std::swap(splitColumns[
ii],
1067 splitColumns[
ii+
randint(features.shape(1) -
ii)]);
1072 int num2try = features.shape(1);
1076 bgfunc(columnVector(features, splitColumns[
k]),
1080 min_gini_[
k] = bgfunc.min_gini_;
1081 min_indices_[
k] = bgfunc.min_index_;
1082 min_thresholds_[
k] = bgfunc.min_threshold_;
1083#ifdef CLASSIFIER_TEST
1091 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
1092 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
1097 num2try = SB::ext_param_.actual_mtry_;
1110 node.threshold() = min_thresholds_[bestSplitIndex];
1111 node.column() = splitColumns[bestSplitIndex];
1115 sorter(features, node.column(), node.threshold());
1121 childRegions[0].rule.push_back(std::make_pair(1, 1.0));
1124 childRegions[1].rule.push_back(std::make_pair(1, 1.0));
1126 return i_ThresholdNode;
1171 std::ptrdiff_t min_index_;
1172 double min_threshold_;
1181 class_weights_(
ext.class_weights_),
1184 bestCurrentCounts[0].resize(
ext.class_count_);
1185 bestCurrentCounts[1].resize(
ext.class_count_);
1191 class_weights_ =
ext.class_weights_;
1193 bestCurrentCounts[0].resize(
ext.class_count_);
1194 bestCurrentCounts[1].resize(
ext.class_count_);
1207 std::sort(begin, end,
1210 LossTraits<LineSearchLossTag, DataSource_t>::type
LineSearchLoss;
1215 min_gini_ = NumericTraits<double>::max();
1216 min_index_ =
floor(
double(end - begin)/2.0);
1217 min_threshold_ = column[*(begin + min_index_)];
1219 sorter(column, 0, min_threshold_);
1233 min_threshold_ = column[*
part];
1235 min_gini_ = right.decrement(begin,
part)
1236 + left.increment(begin ,
part);
1238 bestCurrentCounts[0] = left.response();
1239 bestCurrentCounts[1] = right.response();
1241 min_index_ =
part - begin;
1244 template<
class DataSource_t,
class Iter,
class Array>
1251 LossTraits<LineSearchLossTag, DataSource_t>::type
LineSearchLoss;
1272 std::ptrdiff_t min_index_;
1273 double min_threshold_;
1284 class_weights_(
ext.class_weights_),
1288 bestCurrentCounts[0].resize(
ext.class_count_);
1289 bestCurrentCounts[1].resize(
ext.class_count_);
1295 class_weights_(
ext.class_weights_),
1299 bestCurrentCounts[0].resize(
ext.class_count_);
1300 bestCurrentCounts[1].resize(
ext.class_count_);
1306 class_weights_ =
ext.class_weights_;
1308 bestCurrentCounts[0].resize(
ext.class_count_);
1309 bestCurrentCounts[1].resize(
ext.class_count_);
1322 std::sort(begin, end,
1325 LossTraits<LineSearchLossTag, DataSource_t>::type
LineSearchLoss;
1331 min_gini_ = NumericTraits<double>::max();
1332 int tmp_pt = random.uniformInt(std::distance(begin, end));
1334 min_threshold_ = column[*(begin + min_index_)];
1336 sorter(column, 0, min_threshold_);
1350 min_threshold_ = column[*
part];
1352 min_gini_ = right.decrement(begin,
part)
1353 + left.increment(begin ,
part);
1355 bestCurrentCounts[0] = left.response();
1356 bestCurrentCounts[1] = right.response();
1358 min_index_ =
part - begin;
1361 template<
class DataSource_t,
class Iter,
class Array>
1368 LossTraits<LineSearchLossTag, DataSource_t>::type
LineSearchLoss;
Definition rf_split.hxx:832
void operator()(DataSourceF_t const &column, DataSource_t const &labels, I_Iter &begin, I_Iter &end, Array const &region_response)
Definition rf_split.hxx:892
Definition rf_split.hxx:357
static double impurity(Array const &hist, double total)
Definition rf_split.hxx:381
double operator()(Array const &hist, Array2 const &weights, double total=1.0) const
Definition rf_split.hxx:363
static double impurity(Array const &hist, Array2 const &weights, double total)
Definition rf_split.hxx:389
double operator()(Array const &hist, double total=1.0) const
Definition rf_split.hxx:373
Definition rf_split.hxx:419
static double impurity(Array const &hist, double total)
Definition rf_split.hxx:443
double operator()(Array const &hist, Array2 const &weights, double total=1.0) const
Definition rf_split.hxx:425
static double impurity(Array const &hist, Array2 const &weights, double total)
Definition rf_split.hxx:451
double operator()(Array const &hist, double total=1.0) const
Definition rf_split.hxx:435
Definition rf_nodeproxy.hxx:88
Class for a single RGB value.
Definition rgbvalue.hxx:128
Definition rf_split.hxx:306
Definition rf_split.hxx:202
Definition rf_split.hxx:93
int findBestSplit(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region, ArrayVector< Region >, Random)
Definition rf_split.hxx:150
int makeTerminalNode(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region &region, Random)
Definition rf_split.hxx:168
void set_external_parameters(ProblemSpec< T > const &in)
Definition rf_split.hxx:112
void reset()
Definition rf_split.hxx:137
Definition rf_split.hxx:994
void init(Iterator i, Iterator end)
Definition tinyvector.hxx:708
size_type size() const
Definition tinyvector.hxx:913
iterator end()
Definition tinyvector.hxx:864
iterator begin()
Definition tinyvector.hxx:861
Definition rf_split.hxx:1265
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
int floor(FixedPoint< IntBits, FracBits > v)
rounding down.
Definition fixedpoint.hxx:667
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality.
Definition mathutil.hxx:1638
std::ptrdiff_t MultiArrayIndex
Definition multi_fwd.hxx:60