
A Python implementation of Bertsimas's paper "Optimal Classification Trees".


# pyoptree
Python Optimal Tree

### Install
#### First install pyoptree through pip
```
pip3 install pyoptree
```

#### Then install solver (IMPORTANT!)
You need to have **IBM CPLEX** or **Gurobi** installed on your computer, and **the solver executable must be on the `PATH` environment variable** (i.e. the command `cplex` or `gurobi` can be run from a terminal).
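A quick way to confirm the PATH requirement above is to ask Python where it would find each command. This is only a minimal sketch using the standard library; the helper name `find_solver` is made up for illustration and is not part of pyoptree.

```python
import shutil


def find_solver(candidates=("cplex", "gurobi")):
    """Return {command: resolved path or None} for each candidate command.

    `find_solver` is a hypothetical helper, not a pyoptree function.
    """
    return {name: shutil.which(name) for name in candidates}


# Report which of the two commands named above are visible on PATH.
for name, path in find_solver().items():
    print(f"{name}: {path or 'not found on PATH'}")
```

If both commands are reported as not found, the solver is either not installed or its executable directory has not been added to `PATH`.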

### Example
```python
import pandas as pd
from pyoptree.optree import OptimalHyperTreeModel, OptimalTreeModel

# Training set: two features (x1, x2) and a binary label y in {1, -1}.
data = pd.DataFrame({
    "index": ['A', 'C', 'D', 'E', 'F'],
    "x1": [1, 2, 2, 2, 3],
    "x2": [1, 2, 1, 0, 1],
    "y": [1, 1, -1, -1, -1]
})

# Test set: the training points plus two unseen points (B and G).
test_data = pd.DataFrame({
    "index": ['A', 'B', 'C', 'D', 'E', 'F', 'G'],
    "x1": [1, 1, 2, 2, 2, 3, 3],
    "x2": [1, 2, 2, 1, 0, 1, 0],
    "y": [1, 1, 1, -1, -1, -1, -1]
})

# Feature columns, label column, then the hyper-parameters and the solver to call.
model = OptimalHyperTreeModel(["x1", "x2"], "y", tree_depth=2, N_min=1, alpha=0.1, solver_name="cplex")
model.train(data)

print(model.predict(test_data))
```

### Todos
1. Implement a "warm start" to reduce the time needed to solve the Mixed Integer Linear Programming (MILP) problem. (Done √)
2. Implement heuristics such as CART and genetic algorithms to approximate the optimal solution quickly (possibly at some loss of accuracy).

### Hyper-parameter Tuning Experience
Generally, there are three hyper-parameters that can be tuned for the Optimal Tree model:
- **tree_depth**: the depth of the tree.
- **N_min**: the minimum number of samples in each non-empty leaf node (empty leaves are allowed).
- **alpha**: the coefficient of the regularization term.
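To explore these three hyper-parameters systematically, one option is to enumerate every combination of candidate values. The sketch below does this with the standard library only. The candidate values and the helper name `expand_grid` are assumptions chosen for illustration, not recommendations from this project.

```python
from itertools import product

# Hypothetical candidate values, for illustration only.
grid = {
    "tree_depth": [2, 3],
    "N_min": [10, 30],
    "alpha": [0.1],
}


def expand_grid(grid):
    """Yield one keyword-argument dict per combination of candidate values."""
    names = list(grid)
    for values in product(*(grid[name] for name in names)):
        yield dict(zip(names, values))


settings = list(expand_grid(grid))
for kwargs in settings:
    print(kwargs)
```

Each dict could then be passed to the model constructor from the Example section, e.g. `OptimalHyperTreeModel(["x1", "x2"], "y", solver_name="cplex", **kwargs)`.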

Experiments on the Iris dataset, run on a 4-thread computer with CPLEX 12.8, gave the following results:

| Parameter | Value | Solving time | Other parameter settings | Optimal objective value |
| ------ | ------ | ------ | ------ | ------ |
| N_min | 30 | 20.80 sec | tree_depth = 2, alpha = 0.1 | 3.5 |
| N_min | 10 | 3027.00 sec | tree_depth = 2, alpha = 0.1 | 3.5 |

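From these results, we can draw the following insight:
1. **N_min** has a large impact on solving time: the larger N_min is, the faster the problem is solved (20.80 sec at N_min = 30 versus 3027.00 sec at N_min = 10, with the same objective value of 3.5). However, a larger N_min may make the problem infeasible, so N_min must be tuned with care.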
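The infeasibility risk can be partly screened before calling the solver. Since every non-empty leaf must hold at least N_min samples, the number of non-empty leaves cannot exceed the number of samples divided by N_min. The sketch below computes that bound, assuming a tree of depth D has 2^D leaves and that all 150 Iris samples are used for training. Both function names are hypothetical, and the check is a necessary condition only, not a guarantee of feasibility.

```python
def max_nonempty_leaves(n_samples, tree_depth, n_min):
    """Upper bound on the number of leaves that can hold samples.

    Assumes a tree of depth `tree_depth` has 2**tree_depth leaves and that
    every non-empty leaf needs at least `n_min` samples.
    """
    return min(2 ** tree_depth, n_samples // n_min)


def n_min_is_plausible(n_samples, tree_depth, n_min):
    """Necessary (not sufficient) condition: at least one leaf can be filled."""
    return max_nonempty_leaves(n_samples, tree_depth, n_min) >= 1


# Settings from the table above, assuming 150 training samples (full Iris).
print(max_nonempty_leaves(150, 2, 30))   # 4
print(max_nonempty_leaves(150, 2, 10))   # 4
# An N_min larger than the training set can never be satisfied.
print(n_min_is_plausible(150, 2, 200))   # False
```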

More experimental results will be added in the future.
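Further measurements of this kind could be collected with a small timing harness. The sketch below is one possible shape for it, not the procedure used for the table above. `time_training` and `DummyModel` are hypothetical names, and the dummy stands in for a real model so that the snippet runs without a solver.

```python
import time


def time_training(make_model, data, settings):
    """Train one model per setting and record wall-clock training time.

    `make_model(**kwargs)` is a caller-supplied factory (hypothetical here)
    that returns an object with a `train(data)` method.
    """
    rows = []
    for kwargs in settings:
        model = make_model(**kwargs)
        start = time.perf_counter()
        model.train(data)
        elapsed = time.perf_counter() - start
        rows.append({**kwargs, "seconds": elapsed})
    return rows


class DummyModel:
    """Stand-in so the sketch runs without a solver; swap in a real model."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def train(self, data):
        time.sleep(0.01)  # pretend to solve something


rows = time_training(DummyModel, data=None, settings=[{"N_min": 30}, {"N_min": 10}])
for row in rows:
    print(row)
```

With pyoptree, `make_model` could be a small function that forwards `**kwargs` to the constructor call shown in the Example section.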
