Migrate amc4th training (microsoft#1316)
* Migrate amc4th training
* Refine RL example scripts
* Resolve PR comments

Co-authored-by: luocy16 <[email protected]>
Showing 19 changed files with 676 additions and 50 deletions.
@@ -0,0 +1,55 @@
This folder contains a simple example of how to run Qlib RL. It contains:

```
.
├── experiment_config
│   ├── backtest            # Backtest config
│   └── training            # Training config
├── README.md               # Readme (the current file)
└── scripts                 # Scripts for data pre-processing
```

## Data preparation

Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:

```
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
mv qlib_rl_example_data data
```

The downloaded data will be placed at `./data`. The original data are in `data/csv`. To create all the data needed by this example, run:

```
bash scripts/data_pipeline.sh
```

After the execution finishes, the `data/` directory should look like this:

```
data
├── backtest_orders.csv
├── bin
├── csv
├── pickle
├── pickle_dataframe
└── training_order_split
```
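A quick sanity check of the layout (an illustrative snippet, not part of the pipeline):

```
from pathlib import Path

# Confirm the expected entries exist under ./data before moving on.
expected = ["backtest_orders.csv", "bin", "csv", "pickle", "pickle_dataframe", "training_order_split"]
missing = [name for name in expected if not (Path("data") / name).exists()]
print("missing entries:", missing or "none")
```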

## Run training

Run:

```
python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml
```

After training, checkpoints will be stored under `checkpoints/`.

## Run backtest

```
python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py
```

The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
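The exact layout of the summary depends on the backtest script; assuming it writes CSV files somewhere under `outputs/`, they can be skimmed with pandas:

```
from pathlib import Path
import pandas as pd

# Assumption: the backtest summary is written as CSV files under ./outputs.
for csv_path in sorted(Path("outputs").rglob("*.csv")):
    df = pd.read_csv(csv_path)
    print(csv_path, df.shape)
```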
@@ -0,0 +1,53 @@
_base_ = ["./twap.yml"]

strategies = {
    "_delete_": True,
    "30min": {
        "class": "TWAPStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {},
    },
    "1day": {
        "class": "SAOEIntStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {
            "state_interpreter": {
                "class": "FullHistoryStateInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "max_step": 8,
                    "data_ticks": 240,
                    "data_dim": 6,
                    "processed_data_provider": {
                        "class": "PickleProcessedDataProvider",
                        "module_path": "qlib.rl.data.pickle_styled",
                        "kwargs": {
                            "data_dir": "./data/pickle_dataframe/feature",
                        },
                    },
                },
            },
            "action_interpreter": {
                "class": "CategoricalActionInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "values": 14,
                    "max_step": 8,
                },
            },
            "network": {
                "class": "Recurrent",
                "module_path": "qlib.rl.order_execution.network",
                "kwargs": {},
            },
            "policy": {
                "class": "PPO",
                "module_path": "qlib.rl.order_execution.policy",
                "kwargs": {
                    "lr": 1.0e-4,
                    "weight_file": "./checkpoints/latest.pth",
                },
            },
        },
    },
}
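This file inherits `./twap.yml` through `_base_`, and the `"_delete_": True` entry follows the common convention that the inherited `strategies` section is replaced outright rather than deep-merged with the base. A minimal sketch of that merge rule (an assumption about the convention, not Qlib's actual config loader):

```
# Illustrative `_base_` / `_delete_` merge rule; not Qlib's actual loader.
def merge_config(base: dict, override: dict) -> dict:
    override = dict(override)  # avoid mutating the caller's dict
    if override.pop("_delete_", False):
        return override  # replace the inherited section wholesale
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)  # deep-merge nested dicts
        else:
            merged[key] = value  # scalars and lists simply override
    return merged
```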
@@ -0,0 +1,21 @@
order_file: ./data/backtest_orders.csv
start_time: "9:45"
end_time: "14:44"
qlib:
  provider_uri_1min: ./data/bin
  feature_root_dir: ./data/pickle
  feature_columns_today: [
    "$open", "$high", "$low", "$close", "$vwap", "$volume",
  ]
  feature_columns_yesterday: [
    "$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
  ]
exchange:
  limit_threshold: ['$close == 0', '$close == 0']
  deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
  volume_threshold:
    all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
    buy: ["current", "$close"]
    sell: ["current", "$close"]
strategies: {}  # Placeholder
concurrency: 5
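In the exchange settings above, `deal_price` falls back to `$vwap` whenever a bar's `$close` is zero, and `limit_threshold` uses the same `$close == 0` condition, which reads as treating zero-close bars as untradable. A rough per-bar Python equivalent of the deal-price expression (illustrative only):

```
# Rough equivalent of the deal_price expression "If($close == 0, $vwap, $close)".
def deal_price(close: float, vwap: float) -> float:
    return vwap if close == 0 else close
```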
@@ -0,0 +1,59 @@
simulator:
  time_per_step: 30
  vol_limit: null
env:
  concurrency: 1
  parallel_mode: dummy
action_interpreter:
  class: CategoricalActionInterpreter
  kwargs:
    values: 14
    max_step: 8
  module_path: qlib.rl.order_execution.interpreter
state_interpreter:
  class: FullHistoryStateInterpreter
  kwargs:
    data_dim: 6
    data_ticks: 240
    max_step: 8
    processed_data_provider:
      class: PickleProcessedDataProvider
      module_path: qlib.rl.data.pickle_styled
      kwargs:
        data_dir: ./data/pickle_dataframe/feature
  module_path: qlib.rl.order_execution.interpreter
reward:
  class: PAPenaltyReward
  kwargs:
    penalty: 100.0
  module_path: qlib.rl.order_execution.reward
data:
  source:
    order_dir: ./data/training_order_split
    data_dir: ./data/pickle_dataframe/backtest
    total_time: 240
    default_start_time: 0
    default_end_time: 240
    proc_data_dim: 6
  num_workers: 0
  queue_size: 20
network:
  class: Recurrent
  module_path: qlib.rl.order_execution.network
policy:
  class: PPO
  kwargs:
    lr: 0.0001
  module_path: qlib.rl.order_execution.policy
runtime:
  seed: 42
  use_cuda: false
trainer:
  max_epoch: 2
  repeat_per_collect: 5
  earlystop_patience: 2
  episode_per_collect: 20
  batch_size: 16
  val_every_n_epoch: 1
  checkpoint_path: ./checkpoints
  checkpoint_every_n_iters: 1
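The numbers in this config line up with each other: a 240-tick (one-minute) session divided into 30-minute simulator steps allows at most 8 decisions per order, matching `max_step: 8` in both interpreters. A quick arithmetic check (the relationship between the fields is inferred from the values, not taken from Qlib's code):

```
# Consistency check of the training-config values (inferred relationship).
data_ticks = 240    # one-minute ticks per session (state_interpreter.kwargs.data_ticks)
time_per_step = 30  # minutes advanced per decision (simulator.time_per_step)
max_step = 8        # decision steps per order (action/state interpreters)

assert data_ticks // time_per_step == max_step  # 240 // 30 == 8
```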
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import pickle
import pandas as pd
from tqdm import tqdm

os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)

# Each source pickle holds a mapping of DataFrames; concatenate them and re-split per instrument.
for tag in ("backtest", "feature"):
    df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
    df = pd.concat(list(df.values())).reset_index()
    df["date"] = df["datetime"].dt.date.astype("datetime64")
    instruments = sorted(set(df["instrument"]))

    os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
    for instrument in tqdm(instruments):
        cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
        cur = cur.set_index(["instrument", "datetime", "date"])
        pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
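Each pickle written by this script is a plain pandas DataFrame indexed by (`instrument`, `datetime`, `date`), so it can be read back directly; the symbol below is only an example:

```
# Read back one per-instrument frame produced above ("AAPL" is just an example symbol).
import pandas as pd

cur = pd.read_pickle("data/pickle_dataframe/feature/AAPL.pkl")
print(cur.index.names)  # ['instrument', 'datetime', 'date']
print(cur.head())
```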
@@ -0,0 +1,14 @@
# Generate `bin` format data
set -e
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min

# Generate pickle format data
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
if [ -e stat/ ]; then
    rm -r stat/
fi
python scripts/collect_pickle_dataframe.py

# Sample orders
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import pandas as pd
import numpy as np
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()

np.random.seed(args.seed)

path = os.path.join("data", "pickle", "backtesttest.pkl")  # TODO: rename file
df = pickle.load(open(path, "rb")).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")

# Sample `num_order` distinct trading days per instrument and create one random order per day.
instruments = sorted(set(df["instrument"]))
df_list = []
for instrument in instruments:
    print(instrument)

    cur_df = df[df["instrument"] == instrument]

    dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]]))

    n = args.num_order
    df_list.append(
        pd.DataFrame({
            "date": sorted(np.random.choice(dates, size=n, replace=False)),
            "instrument": [instrument] * n,
            "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
            "order_type": np.random.randint(low=0, high=2, size=n),
        }).set_index(["date", "instrument"]),
    )

total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")
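The resulting `data/backtest_orders.csv` is a flat table keyed by `date` and `instrument`, with an `amount` column (multiples of 100 between 300 and 1000) and a random `order_type` column. It can be loaded back for a quick look:

```
# Inspect the generated order file (columns follow the DataFrame built above).
import pandas as pd

orders = pd.read_csv("data/backtest_orders.csv", index_col=["date", "instrument"])
print(orders.columns.tolist())  # ['amount', 'order_type']
print(orders.groupby(level="instrument").size())  # orders per instrument
```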
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import yaml
import argparse
import os
from copy import deepcopy

from qlib.contrib.data.highfreq_provider import HighFreqProvider

loader = yaml.FullLoader

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="config.yml")
    parser.add_argument("-d", "--dest", type=str, default=".")
    parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
    args = parser.parse_args()

    conf = yaml.load(open(args.config), Loader=loader)

    # Prefix any relative paths in the config with the destination directory.
    for k, v in conf.items():
        if isinstance(v, dict) and "path" in v:
            v["path"] = os.path.join(args.dest, v["path"])
    provider = HighFreqProvider(**conf)

    # Generate dataframes
    if "feature_conf" in conf:
        feature = provider._gen_dataframe(deepcopy(provider.feature_conf))
    if "backtest_conf" in conf:
        backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf))

    provider.feature_conf['path'] = os.path.splitext(provider.feature_conf['path'])[0] + '/'
    provider.backtest_conf['path'] = os.path.splitext(provider.backtest_conf['path'])[0] + '/'

    # Split by date
    if args.split == "date" or args.split == "both":
        provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")

    # Split by stock
    if args.split == "stock" or args.split == "both":
        provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import pandas as pd
import numpy as np
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()

np.random.seed(args.seed)

os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)

# For each split, sample trading days from the corresponding backtest pickle
# and create one order per sampled day for the chosen stock (order_type 0).
for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
    path = os.path.join("data", "pickle", f"backtest{group}.pkl")
    df = pickle.load(open(path, "rb")).reset_index()
    df["date"] = df["datetime"].dt.date.astype("datetime64")

    dates = sorted(set([str(d).split(" ")[0] for d in df["date"]]))

    data_df = pd.DataFrame({
        "date": sorted(np.random.choice(dates, size=n, replace=False)),
        "instrument": [args.stock] * n,
        "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
        "order_type": [0] * n,
    }).set_index(["date", "instrument"])

    os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
    pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb"))