
trade-learn Python Package

Project description

trade-learn: Building Trading Strategies in Python with Machine Learning

trade-learn is a machine learning strategy development toolkit built on alphalens, backtrader, pyfolio, and quantstats. It covers the complete strategy development process.     [ Chinese introduction ]

Its functions include factor collection, factor processing, factor evaluation, causal analysis, model definition, and strategy backtesting, and visualization results can be saved as HTML files for sharing.


[Figure: summary of visualizations]

Key Features

  1. Integrates strategy development components from the Quantopian open-source platform, such as the empyrical, alphalens, and pyfolio toolkits.
  2. Provides stock quotes from Yahoo Finance together with the corresponding factor formulas for the alpha101 and alpha191 factor sets.
  3. Provides stock quotes from the Tongdaxin trading software and 30 verified technical indicators (tdx30) that can be used directly on the Tongdaxin platform.
  4. Signal-driven trading strategies with multiple templates for quickly building and backtesting strategies, supporting both speculative (single-target) and portfolio strategies.
  5. Causal graph construction and causal feature selection algorithms, plus an extended gplearn function library for "feature derivation" on time-series data.
  6. Exploratory analysis and optimal model selection tools for quickly previewing a dataset's patterns and how common models perform on it.
  7. A trimmed backtrader backtesting framework with fewer unnecessary dependencies and backtest results optimized for HTML display, providing more user-friendly interactive visualization.
  8. The entire strategy-building process forms a closed loop for machine learning strategy development; apart from model definition, no additional third-party packages are needed.

Download

pip install trade-learn                                              # release from PyPI
pip install git+https://github.com/MuuYesen/trade-learn.git@master   # latest code from the GitHub master branch

Usage Template

from tradelearn.trader.signal import Signal
from tradelearn.strategy.backtest.single import LongBacktest

# Data retrieval
raw_data, base_line = "Target stock data", "Benchmark stock data"

# Define backtest start and end dates
bt_begin_date, bt_end_date = "Backtest start date", "Backtest end date"

# Define Signal class
class Example(Signal):

    def __init__(self, stockid, raw_data, bt_begin_date, bt_end_date, param_dict):
        signal_df = "Computed signal series containing True, False, and np.nan values, with dates as the index"
        
        self.set_signal(signal_df)

# Signal class parameter dictionary
param_dict = {'fea_list': "Set of variable names used to generate signals"}

# Run backtest
res = LongBacktest.run(Example, param_dict, raw_data, base_line, bt_begin_date, bt_end_date)
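The template states only that `signal_df` is a date-indexed series of True, False, and NaN values. The sketch below builds such a series from threshold rules with pandas. The helper `threshold_signal` is hypothetical, and reading True as "enter or hold a long position" and False as "exit" is an assumption inferred from the RSI example in this README, not a documented rule.

```python
import numpy as np
import pandas as pd


def threshold_signal(values: pd.Series, buy_below: float, sell_above: float) -> pd.Series:
    """Map an indicator series to a True/False/NaN signal series (hypothetical helper).

    Assumed meaning: True = enter or hold long, False = exit, NaN = no new instruction.
    """
    # Object dtype so the series can hold True, False, and NaN side by side
    signal = pd.Series(np.nan, index=values.index, dtype=object)
    signal[values < buy_below] = True
    signal[values > sell_above] = False
    return signal


if __name__ == "__main__":
    dates = pd.date_range("2020-01-01", periods=5, freq="D")
    indicator = pd.Series([15.0, 30.0, 45.0, 10.0, 35.0], index=dates)
    print(threshold_signal(indicator, buy_below=20, sell_above=40))
```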

Simple Examples

Using volume and price indicators for single stock trading

from tradelearn.query import Query  # Import data query module
from tradelearn.trader.signal import Signal  # Import strategy signal class
from tradelearn.strategy.backtest.single import LongBacktest  # Import single stock backtest module
from tradelearn.strategy.evaluate import Evaluate  # Import strategy evaluation module

import numpy as np


if __name__ == '__main__':
    
    # Define data start and end dates
    tn_begin_date = '2017-01-01'
    tn_end_date = '2022-06-22'

    # Query historical data for stock 600520 as the benchmark
    baseline = Query.history_ohlc(symbol='600520', start=tn_begin_date, end=tn_end_date, adjust='hfq', engine='tdx')

    # Retrieve raw data and add labels
    rawdata = Query.history_ohlc(symbol='600520', start=tn_begin_date, end=tn_end_date, adjust='hfq', engine='tdx')
    rawdata['label'] = rawdata['close'].pct_change(periods=5).shift(-1).map(lambda x: 1 if x > 0 else -1)

    # Define backtest start and end dates
    bt_begin_date = '2020-01-01'
    bt_end_date = '2022-06-22'
    
    # Define RSI signal class
    class RSI(Signal):

        def __init__(self, stockid, raw_data, bt_begin_date, bt_end_date, param_dict):
            
            indi = Query.tec_indicator(raw_data, ['RSI']) # Calculate Relative Strength Index (RSI)

            # Generate signals for the entire period
            def signal(x):
                if x < 20:
                    return True
                if x > 40:
                    return False
                return np.nan
            indi = indi.set_index('date').applymap(signal)

            # Retain signals for the backtest period
            bt_indi = indi.query(f"date >= '{bt_begin_date}' and date < '{bt_end_date}'")

            self.set_signal(bt_indi)
    
    param_dict = {}
    
    # Run backtest
    res = LongBacktest.run(RSI, param_dict, rawdata, baseline, bt_begin_date, bt_end_date)

    # Analyze backtest results
    Evaluate.analysis_report(res, baseline, engine='quantstats')
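`Query.tec_indicator` computes RSI internally, and this README does not give its formula. For reference, a common formulation (Wilder's) averages gains and losses with an exponential weight of alpha = 1/n. The `wilder_rsi` helper is a hypothetical sketch; the tdx30 indicator may use a different period or smoothing, so its values can differ from the library's.

```python
import pandas as pd


def wilder_rsi(close: pd.Series, n: int = 14) -> pd.Series:
    """RSI using Wilder smoothing: exponential averages with alpha = 1/n (hypothetical helper)."""
    delta = close.diff()
    gain = delta.clip(lower=0)       # upward moves, 0 otherwise
    loss = (-delta).clip(lower=0)    # size of downward moves, 0 otherwise
    avg_gain = gain.ewm(alpha=1 / n, adjust=False).mean()
    avg_loss = loss.ewm(alpha=1 / n, adjust=False).mean()
    rs = avg_gain / avg_loss         # inf when there are no losses, which gives RSI = 100
    return 100 - 100 / (1 + rs)


if __name__ == "__main__":
    prices = pd.Series([10.0, 10.2, 10.1, 10.4, 10.3, 10.6, 10.5, 10.2, 10.0, 10.3])
    print(wilder_rsi(prices, n=6).round(2))
```

The output can be fed to the same `< 20` / `> 40` thresholds used in the `RSI` signal class.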

Using machine learning models to build a portfolio

from tradelearn.query import Query  # Import data query module
from tradelearn.trader.signal import Signal  # Import strategy signal class
from tradelearn.strategy.backtest.fund import LongBacktest  # Import portfolio backtest module
from tradelearn.strategy.evaluate import Evaluate  # Import strategy evaluation module

import pandas as pd
from dateutil.relativedelta import relativedelta

from sklearn.ensemble import RandomForestClassifier  # Import Random Forest classifier


if __name__ == '__main__':
    
    # Define data start and end dates
    tn_begin_date = '2017-01-01'
    tn_end_date = '2022-06-22'

    # Query historical data for the Shanghai Composite Index as the benchmark
    baseline = Query.history_ohlc(symbol='000001.SS', start=tn_begin_date, end=tn_end_date, engine='yahoo')

    rawdata = None
    # Loop to query historical data for multiple stocks and process
    for i in range(10):
        temp = Query.history_ohlc(symbol='60052' + str(i), start=tn_begin_date, end=tn_end_date, adjust='hfq', engine='tdx')
        if temp is None:
            continue

        # Label the data
        temp['label'] = temp['close'].pct_change(periods=5).shift(-1).map(lambda x: 1 if x > 0 else -1)
        rawdata = pd.concat([rawdata, temp], axis=0)

    # Define backtest start and end dates
    bt_begin_date = '2020-01-01'
    bt_end_date = '2022-06-22'
    
    # Define Random Forest indicator class and use rolling prediction to generate trading signals
    class RandomForest(Signal):

        model_dict = {}  # Trained models keyed by year, shared by all stocks

        def __init__(self, stockid, raw_data, bt_begin_date, bt_end_date, param_dict):
            fea_list = param_dict['fea_list']
            
            if not RandomForest.model_dict:
                # Build Random Forest models and save to the model dictionary
                for date in pd.date_range(start=bt_begin_date, end=bt_end_date, freq='12MS'):
                    bt_train_data = raw_data.query(f"date >= '{date - relativedelta(months=12 * 3)}' and date < '{date}'")
                    bt_x_train, bt_y_train = bt_train_data[fea_list], bt_train_data['label']

                    model = RandomForestClassifier(random_state=42, n_jobs=-1)
                    model.fit(bt_x_train, bt_y_train)
                    RandomForest.model_dict[date.year] = model

            # Use models for prediction
            indi_df = None
            for date in pd.date_range(start=bt_begin_date, end=bt_end_date, freq='12MS'):
                pos_data = raw_data.query(f"code == '{stockid}' and date >= '{date}' and date < '{date + relativedelta(months=12 * 1)}'")
                bt_x_test = pos_data.set_index(['date'])[fea_list]
                pre_proba = RandomForest.model_dict[date.year].predict_proba(bt_x_test)[:, 1]
                indi_df = pd.concat([indi_df, pd.DataFrame(pre_proba, index=pos_data['date'])])

            self.set_signal(indi_df)

    # Feature list: all columns except label, code, and date
    fea_list = rawdata.columns.drop(['label', 'code', 'date']).tolist()
    param_dict = {'fea_list': fea_list}
    
    # Run backtest
    res = LongBacktest.run(RandomForest, param_dict, rawdata, baseline, bt_begin_date, bt_end_date)
    
    # Analyze backtest results
    Evaluate.analysis_report(res, baseline, engine='quantstats')
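The `RandomForest` class above retrains one model per calendar year on the preceding three years of data and predicts the following year. The sketch isolates that walk-forward schedule. `walk_forward_windows` is a hypothetical helper that mirrors the `'12MS'` date range and `relativedelta` offsets in the example, using half-open `[start, end)` intervals to match the `date >= ... and date < ...` queries.

```python
import pandas as pd
from dateutil.relativedelta import relativedelta


def walk_forward_windows(begin, end, train_years=3, test_years=1):
    """Return (train_start, train_end, test_start, test_end) tuples, one per yearly refit.

    Intervals are half-open [start, end), matching the `date >= ... and date < ...` queries.
    """
    windows = []
    for test_start in pd.date_range(start=begin, end=end, freq="12MS"):
        train_start = test_start - relativedelta(years=train_years)
        test_end = test_start + relativedelta(years=test_years)
        windows.append((train_start, test_start, test_start, test_end))
    return windows


if __name__ == "__main__":
    for window in walk_forward_windows("2020-01-01", "2022-06-22"):
        print([d.date() for d in window])
```

Two details of the example are worth checking when adapting it. The training query pools all stocks (it does not filter on `code`), while prediction is done per `stockid`. And because labels are built from future closes, the last training labels can depend on prices from the first days of the test year; leaving a gap of a few trading days between the end of training and the start of testing is a common safeguard.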

Method Guide

Retrieving Raw Data

from tradelearn.query import Query

rawdata = Query.history_ohlc(symbol='600520', start='2017-01-01', end='2022-06-22', adjust='hfq', engine='tdx')
| Parameter | Type | Notes |
|---|---|---|
| symbol | string | Stock ticker |
| start | string | Start date |
| end | string | End date |
| adjust | string | Price adjustment method: 'qfq' for forward adjustment or 'hfq' for backward adjustment |
| engine | string | Third-party data source: 'yahoo' for Yahoo Finance or 'tdx' for Tongdaxin |
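The `adjust` options follow the usual Chinese-market convention: forward adjustment (qfq) keeps the latest price as traded and rescales history, while backward adjustment (hfq) keeps the earliest price as traded and rescales later bars. A minimal sketch, assuming a cumulative adjustment-factor series is available (how the 'tdx' and 'yahoo' engines obtain one is not described here); `adjust_prices` is a hypothetical helper.

```python
import pandas as pd


def adjust_prices(close: pd.Series, factor: pd.Series, method: str = "hfq") -> pd.Series:
    """Adjust raw prices with a cumulative adjustment factor (hypothetical helper).

    hfq (backward): anchored to the first bar, later prices are scaled up after each event.
    qfq (forward):  anchored to the last bar, earlier prices are scaled down.
    """
    if method == "hfq":
        return close * factor / factor.iloc[0]
    if method == "qfq":
        return close * factor / factor.iloc[-1]
    raise ValueError("method must be 'hfq' or 'qfq'")


if __name__ == "__main__":
    raw = pd.Series([10.0, 10.5, 5.3, 5.5])     # a 2-for-1 split happens before the third bar
    factor = pd.Series([1.0, 1.0, 2.0, 2.0])    # assumed cumulative adjustment factor
    print(pd.DataFrame({"raw": raw, "hfq": adjust_prices(raw, factor, "hfq"),
                        "qfq": adjust_prices(raw, factor, "qfq")}))
```

Both series remove the artificial drop at the split; they differ only by the constant ratio between the last and first factor values.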

Factor Generation

from tradelearn.query import Query

res = Query.alphas101(stock_data=rawdata, alpha_name=['alpha001'])
res = Query.alphas191(stock_data=rawdata, alpha_name=['alpha001'])
res = Query.tec_indicator(stock_data=rawdata, alpha_name=['ATR', 'RSI'])
| Parameter | Type | Notes |
|---|---|---|
| stock_data | DataFrame | Target market data; must contain the columns open, low, high, close, volume, and vwap |
| alpha_name | list | List of factor or indicator names |
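The alpha101 set follows Kakushadze's "101 Formulaic Alphas". To illustrate how such formulas map onto the required columns, Alpha#101, one of the simplest in the set, can be written directly in pandas. The function name `alpha_101` is hypothetical, and the library's own implementation and output column names may differ.

```python
import pandas as pd


def alpha_101(df: pd.DataFrame) -> pd.Series:
    """Alpha#101: (close - open) / ((high - low) + 0.001); the 0.001 avoids division by zero."""
    return (df["close"] - df["open"]) / ((df["high"] - df["low"]) + 0.001)


if __name__ == "__main__":
    bars = pd.DataFrame({"open": [10.0, 11.0], "high": [12.0, 11.5],
                         "low": [9.0, 10.0], "close": [11.0, 10.5]})
    print(alpha_101(bars))
```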

Exploratory Analysis

from tradelearn.strategy.preprocess.explore import Explore

Explore.analysis_report(data=rawdata, filename='res/explore.html')
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data |
| filename | string | Path and name of the saved HTML file, optional |
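`Explore.analysis_report` produces an HTML report. For a quick look from the console, a pandas-only summary of per-column statistics, missing counts, and correlation with the label gives a rough first preview. `quick_explore` is a hypothetical helper, not part of trade-learn, and it assumes the `code`, `date`, and `label` columns used in the examples.

```python
import pandas as pd


def quick_explore(data: pd.DataFrame, target: str = "label", exclude=("code", "date")) -> pd.DataFrame:
    """Per-column summary statistics, missing counts, and Pearson correlation with the target."""
    numeric = data.drop(columns=[c for c in exclude if c in data.columns]).select_dtypes("number")
    summary = numeric.describe().T                 # one row per numeric column
    summary["missing"] = numeric.isna().sum()
    if target in numeric.columns:
        summary["corr_with_target"] = numeric.corr()[target]
    return summary


if __name__ == "__main__":
    frame = pd.DataFrame({
        "code": ["600520"] * 5,
        "date": pd.date_range("2020-01-01", periods=5),
        "close": [10.0, 10.2, 10.1, None, 10.6],
        "volume": [100.0, 120.0, 90.0, 110.0, 130.0],
        "label": [1, -1, 1, 1, -1],
    })
    print(quick_explore(frame))
```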

Factor Derivation

from tradelearn.strategy.preprocess.derive import Derive

res = Derive.generic_generate(data=rawdata)
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data |
| f_col | list | Variables that take part in factor derivation; derived factors are evaluated by Sharpe ratio. Defaults to all variables except code, date, and label |
| n_alpha | int | Number of derived factors in the final result |
| random_status | int | Random seed; if not set, each run may produce different results |
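The README states only that derived factors are evaluated by Sharpe metrics. One plausible fitness, sketched here as an assumption, is the annualized Sharpe ratio of a strategy that goes long when the factor is positive and short when it is negative, applied to next-bar returns. `factor_sharpe` is a hypothetical helper; to use something like it in gplearn it would have to be wrapped as a custom fitness via `gplearn.fitness.make_fitness`, which expects a function of `y`, `y_pred`, and `sample_weight`.

```python
import numpy as np
import pandas as pd


def factor_sharpe(factor: pd.Series, close: pd.Series, periods_per_year: int = 252) -> float:
    """Annualized Sharpe of holding sign(factor) over each next-bar return (hypothetical fitness)."""
    next_ret = close.pct_change().shift(-1)          # return from bar t to bar t+1
    strat = (np.sign(factor) * next_ret).dropna()    # long if factor > 0, short if < 0
    if len(strat) < 2 or strat.std() == 0:
        return 0.0
    return float(np.sqrt(periods_per_year) * strat.mean() / strat.std())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    close = pd.Series(100 * np.cumprod(1 + rng.normal(0, 0.01, 500)))
    open_ = close.shift(1) * (1 + rng.normal(0, 0.002, 500))
    candidates = {
        "close - open": close - open_,
        "5-bar momentum": close.pct_change(5),
        "negative 5-bar momentum": -close.pct_change(5),
    }
    # Rank candidate factors by the assumed Sharpe fitness, best first
    for name, values in sorted(candidates.items(), key=lambda kv: -factor_sharpe(kv[1], close)):
        print(f"{name}: {factor_sharpe(values, close):.2f}")
```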

Single Factor Test

from tradelearn.strategy.examine import Examine

Examine.single_factor(data=data, col='alpha001_101', filename='res/examine.html')
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data; must contain two or more stocks |
| col | string | Target factor name |
| filename | string | Path and name of the saved HTML file, optional |
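Single-factor tests in the alphalens style center on the information coefficient (IC): the Spearman rank correlation, computed on each date across stocks, between the factor and forward returns. This is why the data must contain two or more stocks. The sketch computes a per-date rank IC with pandas; `add_forward_returns` and `rank_ic` are hypothetical helpers, and the `code`, `date`, and `close` column names follow the examples in this README.

```python
import pandas as pd


def add_forward_returns(data: pd.DataFrame, periods: int = 1) -> pd.DataFrame:
    """Add `fwd_ret`: each stock's return over the next `periods` bars."""
    out = data.sort_values(["code", "date"]).copy()
    out["fwd_ret"] = out.groupby("code")["close"].shift(-periods) / out["close"] - 1
    return out


def rank_ic(data: pd.DataFrame, factor_col: str, ret_col: str = "fwd_ret",
            date_col: str = "date") -> pd.Series:
    """Per-date Spearman rank IC, computed as the Pearson correlation of cross-sectional ranks."""
    def _ic(group: pd.DataFrame) -> float:
        return group[factor_col].rank().corr(group[ret_col].rank())

    return data.groupby(date_col)[[factor_col, ret_col]].apply(_ic)


if __name__ == "__main__":
    prices = pd.DataFrame({
        "code": ["a", "b", "c"] * 3,
        "date": ["2020-01-02"] * 3 + ["2020-01-03"] * 3 + ["2020-01-06"] * 3,
        "close": [10.0, 20.0, 30.0, 10.5, 19.0, 31.5, 10.0, 19.5, 33.0],
        "alpha": [0.3, -0.2, 0.5, 0.1, 0.0, 0.4, 0.2, 0.1, 0.3],
    })
    ic = rank_ic(add_forward_returns(prices).dropna(subset=["fwd_ret"]), "alpha")
    print(ic, "mean IC:", ic.mean())
```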

Multi-Factor Comparison

from tradelearn.strategy.examine import Examine

res = Examine.factor_compare(data=data, f_col=None, ind=None, cir=None)
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data; must contain two or more stocks |
| f_col | list | Factor names to compare; if None, all variables are compared |
| ind | string | Industry field name used in the t-test, optional |
| cir | string | Market capitalization field name used in the t-test, optional |
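The table mentions a t-test with optional industry and market-cap controls but not its exact form. A common approach, sketched here as an assumption, is the Fama-MacBeth procedure: regress forward returns on the factor (plus industry dummies and log market cap) separately on each date, then t-test the mean of the per-date factor slopes. `factor_return_tstat` is a hypothetical helper built on NumPy least squares.

```python
import numpy as np
import pandas as pd


def factor_return_tstat(data, factor_col, ret_col="fwd_ret", date_col="date", ind_col=None, cap_col=None):
    """Fama-MacBeth style test: per-date OLS slopes on the factor, then a t-stat of their mean."""
    slopes = []
    for _, g in data.groupby(date_col):
        g = g.dropna(subset=[factor_col, ret_col])
        cols = [np.ones(len(g)), g[factor_col].to_numpy(dtype=float)]   # intercept, factor
        if cap_col is not None:
            cols.append(np.log(g[cap_col].to_numpy(dtype=float)))        # log market cap control
        if ind_col is not None:
            dummies = pd.get_dummies(g[ind_col], drop_first=True, dtype=float)
            cols.extend(dummies.to_numpy().T)                            # one column per extra industry
        X = np.column_stack(cols)
        y = g[ret_col].to_numpy(dtype=float)
        if len(y) <= X.shape[1]:
            continue  # too few stocks on this date to fit the regression
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        slopes.append(coef[1])  # coefficient on the factor
    slopes = np.asarray(slopes)
    t_stat = slopes.mean() / (slopes.std(ddof=1) / np.sqrt(len(slopes)))
    return slopes, float(t_stat)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    frames = []
    for d in range(50):
        f = rng.normal(size=30)
        frames.append(pd.DataFrame({"date": d, "factor": f,
                                    "fwd_ret": 0.01 * f + rng.normal(scale=0.01, size=30)}))
    panel = pd.concat(frames, ignore_index=True)
    slopes, t = factor_return_tstat(panel, "factor")
    print(f"mean slope {slopes.mean():.4f}, t-stat {t:.1f}")
```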

Causal Feature Selection

from tradelearn.causal.blanket import Blanket

Blanket.fit_causal(data=rawdata, method='iamb', target_name='volume', is_discrete=False)
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data |
| method | string | Causal feature selection algorithm: 'iamb' or 'pcmb' |
| target_name | string | Name of the dependent (target) variable |
| alpha | float | Significance level of the conditional independence tests, typically 0.05 or 0.01 |
| is_discrete | bool | Set to True if the data is discrete |
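IAMB (Incremental Association Markov Blanket) finds the target's Markov blanket in two phases: a growing phase that repeatedly adds the candidate most associated with the target given the current blanket (here, the smallest p-value, as a simplification), and a shrinking phase that removes any member independent of the target given the rest. A minimal sketch for continuous data (`is_discrete=False`), assuming a Gaussian Fisher-z partial-correlation test; the test trade-learn actually uses is not stated here, and discrete data would need a different one (e.g., G² or chi-square). `fisher_z_pvalue` and `iamb` are hypothetical helpers.

```python
import math

import numpy as np


def fisher_z_pvalue(data: np.ndarray, i: int, j: int, cond) -> float:
    """Two-sided p-value for H0: column i is independent of column j given columns `cond`.

    Uses the Gaussian partial correlation from the inverse correlation matrix and Fisher's z.
    """
    idx = [i, j] + list(cond)
    prec = np.linalg.pinv(np.corrcoef(data[:, idx], rowvar=False))
    r = -prec[0, 1] / math.sqrt(prec[0, 0] * prec[1, 1])
    r = min(max(r, -0.999999), 0.999999)  # keep the log finite
    z = 0.5 * math.log((1 + r) / (1 - r)) * math.sqrt(data.shape[0] - len(cond) - 3)
    return math.erfc(abs(z) / math.sqrt(2))


def iamb(data: np.ndarray, target: int, alpha: float = 0.05) -> list:
    """Markov blanket of column `target` via IAMB (growing phase, then shrinking phase)."""
    candidates = [v for v in range(data.shape[1]) if v != target]
    blanket = []
    while True:  # growing: add the most strongly associated variable that is still dependent
        best, best_p = None, alpha
        for v in candidates:
            if v in blanket:
                continue
            p = fisher_z_pvalue(data, target, v, blanket)
            if p < best_p:
                best, best_p = v, p
        if best is None:
            break
        blanket.append(best)
    for v in list(blanket):  # shrinking: drop members made independent by the others
        rest = [u for u in blanket if u != v]
        if fisher_z_pvalue(data, target, v, rest) >= alpha:
            blanket.remove(v)
    return sorted(blanket)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 3000
    x1, x2, x4 = rng.normal(size=(3, n))
    t = x1 + x2 + 0.5 * rng.normal(size=n)   # x1 and x2 are parents of t
    x3 = t + 0.5 * rng.normal(size=n)        # x3 is a child of t
    x5 = x1 + 0.5 * rng.normal(size=n)       # x5 lies outside t's Markov blanket
    data = np.column_stack([x1, x2, t, x3, x4, x5])
    print(iamb(data, target=2, alpha=0.01))  # should contain columns 0, 1, and 3
```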

Causal Graph Construction

from tradelearn.causal.graph import Graph

Graph.fit_causal(data=rawdata, method='pc', is_discrete=False, filename='res/pc.png')
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data |
| method | string | Causal graph construction algorithm: 'pc' or 'ges' |
| is_discrete | bool | Set to True if the data is discrete |
| filename | string | Path and name of the saved causal graph, optional |
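The PC algorithm first learns an undirected skeleton: starting from a complete graph, it deletes the edge between two variables whenever some conditioning set drawn from their neighbors makes them independent, increasing the conditioning-set size level by level. It then orients the remaining edges. The sketch covers only the skeleton phase, assuming a Fisher-z test for continuous data; `pc_skeleton` is a hypothetical helper, and the orientation step (v-structures and Meek rules) is omitted.

```python
import itertools
import math

import numpy as np


def fisher_z_pvalue(data: np.ndarray, i: int, j: int, cond) -> float:
    """Two-sided p-value for H0: column i is independent of column j given columns `cond`."""
    idx = [i, j] + list(cond)
    prec = np.linalg.pinv(np.corrcoef(data[:, idx], rowvar=False))
    r = -prec[0, 1] / math.sqrt(prec[0, 0] * prec[1, 1])   # partial correlation
    r = min(max(r, -0.999999), 0.999999)
    z = 0.5 * math.log((1 + r) / (1 - r)) * math.sqrt(data.shape[0] - len(cond) - 3)
    return math.erfc(abs(z) / math.sqrt(2))


def pc_skeleton(data: np.ndarray, alpha: float = 0.05) -> set:
    """Skeleton phase of the PC algorithm; returns undirected edges as (i, j) with i < j."""
    p = data.shape[1]
    adj = {i: set(range(p)) - {i} for i in range(p)}  # start from the complete graph
    level = 0  # size of the conditioning sets tested in this pass
    while any(len(adj[i]) - 1 >= level for i in range(p)):
        for i in range(p):
            for j in list(adj[i]):
                others = sorted(adj[i] - {j})
                for cond in itertools.combinations(others, level):
                    if fisher_z_pvalue(data, i, j, cond) >= alpha:
                        adj[i].discard(j)  # independent given `cond`: delete the edge
                        adj[j].discard(i)
                        break
        level += 1
    return {(i, j) for i in range(p) for j in adj[i] if i < j}


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    n = 3000
    x = rng.normal(size=n)
    y = 0.8 * x + rng.normal(size=n)   # chain x -> y -> z
    z = 0.8 * y + rng.normal(size=n)
    w = rng.normal(size=n)             # unrelated variable
    print(pc_skeleton(np.column_stack([x, y, z, w]), alpha=0.01))
```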

Optimal Model Selection

from tradelearn.automl import AutoML

model = AutoML.lazy_predict(data=data)
| Parameter | Type | Notes |
|---|---|---|
| data | DataFrame | Target market data |

Backtest Validation

from tradelearn.strategy.backtest.single import LongBacktest  # Template for single-target speculative strategies (import only one of the two)
from tradelearn.strategy.backtest.fund import LongBacktest    # Template for multi-target portfolio strategies (import only one of the two)

res = LongBacktest.run(model_class=Example, param_dict=param_dict, raw_data=rawdata, base_line=baseline,
                       begin_date=bt_begin_date, end_date=bt_end_date, show_source=True)
| Parameter | Type | Notes |
|---|---|---|
| model_class | Signal | User-defined subclass of Signal (pass the class itself) |
| param_dict | dict | Dictionary of parameters passed to the signal class |
| raw_data | DataFrame | Target market data |
| base_line | DataFrame | Benchmark market data |
| begin_date | string | Start date of the backtest |
| end_date | string | End date of the backtest |
| show_source | bool | Whether to show the strategy source code in the HTML file; default is True |
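`LongBacktest.run` executes signals through the trimmed backtrader engine. To sanity-check a signal class before a full run, a vectorized approximation can help. The sketch assumes that True/False/NaN signals mean enter/exit/no change (inferred from the RSI example, not documented), trades at the next bar, and ignores commissions, slippage, and position sizing, so its results will not match `LongBacktest` exactly. `long_only_equity` is a hypothetical helper, and it does not cover probability signals like those in the Random Forest example.

```python
import numpy as np
import pandas as pd


def long_only_equity(close: pd.Series, signal: pd.Series, initial_cash: float = 1.0) -> pd.Series:
    """Vectorized long-only equity curve (assumed: True = hold, False = flat, NaN = no change)."""
    state = pd.Series(np.nan, index=signal.index)
    state[signal.eq(True)] = 1.0              # explicit comparisons keep NaN entries as NaN
    state[signal.eq(False)] = 0.0
    position = state.ffill().fillna(0.0)      # carry the last instruction forward, start flat
    position = position.shift(1).fillna(0.0)  # act on the next bar, not the signal bar
    returns = close.pct_change().fillna(0.0)
    return initial_cash * (1 + position * returns).cumprod()


if __name__ == "__main__":
    dates = pd.date_range("2020-01-01", periods=5, freq="D")
    close = pd.Series([100.0, 110.0, 121.0, 121.0, 108.9], index=dates)
    signal = pd.Series([True, np.nan, False, np.nan, np.nan], index=dates)
    print(long_only_equity(close, signal))
```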

Strategy Evaluation

from tradelearn.strategy.evaluate import Evaluate

Evaluate.analysis_report(strat=res, baseline=baseline, filename='./evaluate.html', engine='quantstats')
| Parameter | Type | Notes |
|---|---|---|
| strat | dict | Data dictionary returned by LongBacktest.run() |
| baseline | DataFrame | Benchmark market data |
| filename | string | Path and name of the generated HTML file, optional |
| engine | string | Evaluation engine: 'pyfolio' or 'quantstats' |
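quantstats and pyfolio report many metrics under their own conventions (risk-free rate, compounding, annualization). A few headline numbers can be reproduced in pandas as a cross-check. The sketch assumes simple per-period returns, a zero risk-free rate, and 252 periods per year, and counts the starting capital as the first drawdown peak; `summary_metrics` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def summary_metrics(returns: pd.Series, periods_per_year: int = 252) -> dict:
    """Headline metrics from simple per-period returns (zero risk-free rate assumed)."""
    equity = (1 + returns).cumprod()
    peak = np.maximum(equity.cummax(), 1.0)   # starting capital of 1.0 counts as the first peak
    drawdown = equity / peak - 1
    years = len(returns) / periods_per_year
    std = returns.std(ddof=1)
    return {
        "total_return": float(equity.iloc[-1] - 1),
        "cagr": float(equity.iloc[-1] ** (1 / years) - 1),
        "annual_volatility": float(std * np.sqrt(periods_per_year)),
        "sharpe": float(np.sqrt(periods_per_year) * returns.mean() / std) if std > 0 else float("nan"),
        "max_drawdown": float(drawdown.min()),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    daily = pd.Series(rng.normal(0.0005, 0.01, 504))
    print(summary_metrics(daily))
```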

Acknowledgements

Contact Information

WeChat Official Account: 知守溪的收纳屋           Email: muyes88@gmail.com


Source Distributions

No source distribution files are available for this release.

Built Distribution

trade_learn-0.1.1.8-py3-none-any.whl (659.8 kB)

Uploaded Python 3

File details

Details for the file trade_learn-0.1.1.8-py3-none-any.whl.


File hashes

Hashes for trade_learn-0.1.1.8-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e7d4ec3e8d51932adc18321c3b75c5120f6e4569931b5b8cdf6c79ea84a1a9dc |
| MD5 | 41bf5b0eed57ffd10854a3a9e89a8339 |
| BLAKE2b-256 | e70515d2e47d477aa365c7ec5b0f3e9e5b85a25ecf182a2677edc4bbccb808a4 |
