00001
00008
#ifndef _TREE_H_
00009
#define _TREE_H_
00010
#include "librf/types.h"
00011
#include "librf/tree_node.h"
00012
#include <iostream>
00013
#include <vector>
00014
#include <set>
00015
#include <map>
00016
#include <math.h>
00017
using namespace std;
00018
00019
namespace librf {
00020
00021
class InstanceSet;
00022
class weight_list;
00023
class DiscreteDist;
00024
class Instance;
00025
00045 class Tree {
00046
public:
00048
Tree(istream& in);
00050
Tree(
const InstanceSet& set, weight_list* weights,
00051
int K, uchar max_depth=16,
int min_size = 1,
00052
float min_gain = 0,
unsigned int seed =0);
00053 ~
Tree();
00055
int predict(
const Instance& c)
const;
00057
int predict(
const InstanceSet& set,
int instance_no)
const;
00058
00060
float training_accuracy()
const;
00061
00062
float testing_accuracy(
const InstanceSet& testset)
const;
00063
float oob_accuracy()
const;
00064
void oob_cases(weight_list* counts, weight_list* correct)
const;
00065
00066
void variable_importance(map<int, float>* scores,
unsigned int* seed)
const;
00067
void print()
const;
00068
00069
00070
void grow();
00071
void write(ostream& o)
const;
00072
void read(istream& i);
00073
private:
00074
void copy_instances();
00075
void move_data(tree_node* n, uint16 split_attr, uint16 split_idx);
00076
void find_best_split(tree_node* n,
00077
const vector<int>& attrs,
00078
int* split_attr,
int* split_idx,
00079
float* split_point,
float* split_gain);
00080
void find_best_split_for_attr(tree_node* n,
00081
int attr,
00082
float prior,
00083
int* split_idx,
00084
float *split_point,
00085
float* best_gain);
00086
00087
00088
void mark_build(tree_node* n, uint16 start, uint16 size, uchar depth);
00089
void mark_terminal(tree_node* n);
00090
void mark_split(tree_node* n, uint16 split_attr,
float split_point);
00091
00092
void build_tree(
int min_size);
00093
int build_node(uint16 node_num, uint16 min_size);
00094
void print_node(
int n)
const;
00095
00096
00097
00098 uint32 parent(uint32 node_num)
const {
00099
return uint32(floor((node_num - 1) / 2.0));
00100 }
00101 uint32 left_child(uint32 node_num)
const {
00102
return 2 * node_num + 1;
00103 }
00104 uint32 right_child(uint32 node_num)
const {
00105
return 2 * node_num + 2;
00106 }
00107
00108
void permuteOOB(
int m,
double *x);
00109
00110 vector<tree_node> nodes_;
00111 vector<uint16> active_nodes_;
00112 set<uint16> vars_used_;
00113 uint16 terminal_nodes_;
00114 uint16 split_nodes_;
00115
00116
00117
const InstanceSet& set_;
00118
00119
00120
00121
00122
00123
00124
00125
00126 uint16** sorted_inum_;
00127
00128
00129
00130 weight_list* weight_list_;
00131
00132 uint16 max_depth_;
00133 uint16 K_;
00134 uint16 min_size_;
00135
float min_gain_;
00136
00137
int* temp;
00138 uchar* move_left;
00139 uint16 num_instances_;
00140 uint16 num_attributes_;
00141 uint16 stride_;
00142
unsigned int rand_seed_;
00143
00144
static const int kLeft;
00145
static const int kRight;
00146 };
00147
00148 }
00149
#endif