Migrate amc4th training (microsoft#1316)

* Migrate amc4th training

* Refine RL example scripts

* Resolve PR comments

Co-authored-by: luocy16 <[email protected]>

lihuoran and luocy16 authored Oct 19, 2022
1 parent f35d8a3 commit 54ef210
Showing 19 changed files with 676 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -24,6 +24,9 @@ qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/

*.egg-info/

55 changes: 55 additions & 0 deletions examples/rl/README.md
@@ -0,0 +1,55 @@
This folder contains a simple example of how to run Qlib RL. It is organized as follows:

```
.
├── experiment_config
│   ├── backtest     # Backtest config
│   └── training     # Training config
├── README.md        # Readme (the current file)
└── scripts          # Scripts for data pre-processing
```

## Data preparation

Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:

```
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
mv qlib_rl_example_data data
```

The downloaded data will be placed in `./data`, with the original CSV data under `data/csv`. To generate all the data needed by this example, run:

```
bash scripts/data_pipeline.sh
```

After the script finishes, the `data/` directory should look like this:

```
data
├── backtest_orders.csv
├── bin
├── csv
├── pickle
├── pickle_dataframe
└── training_order_split
```
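Each entry corresponds to one stage of `scripts/data_pipeline.sh`: `csv` holds the original downloads, `bin` the dumped Qlib binary data, `pickle` the intermediate feature/backtest frames, `pickle_dataframe` the per-instrument frames, and `backtest_orders.csv` / `training_order_split` the sampled orders.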

## Run training

Run:

```
python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml
```

After training, checkpoints will be stored under `checkpoints/`.

## Run backtest

```
python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py
```

The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
53 changes: 53 additions & 0 deletions examples/rl/experiment_config/backtest/config.py
@@ -0,0 +1,53 @@
_base_ = ["./twap.yml"]

strategies = {
    "_delete_": True,
    "30min": {
        "class": "TWAPStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {},
    },
    "1day": {
        "class": "SAOEIntStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {
            "state_interpreter": {
                "class": "FullHistoryStateInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "max_step": 8,
                    "data_ticks": 240,
                    "data_dim": 6,
                    "processed_data_provider": {
                        "class": "PickleProcessedDataProvider",
                        "module_path": "qlib.rl.data.pickle_styled",
                        "kwargs": {
                            "data_dir": "./data/pickle_dataframe/feature",
                        },
                    },
                },
            },
            "action_interpreter": {
                "class": "CategoricalActionInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "values": 14,
                    "max_step": 8,
                },
            },
            "network": {
                "class": "Recurrent",
                "module_path": "qlib.rl.order_execution.network",
                "kwargs": {},
            },
            "policy": {
                "class": "PPO",
                "module_path": "qlib.rl.order_execution.policy",
                "kwargs": {
                    "lr": 1.0e-4,
                    "weight_file": "./checkpoints/latest.pth",
                },
            },
        },
    },
}
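A note on the two special keys in this file: `_base_` pulls in `./twap.yml` as the base configuration, and `"_delete_": True` tells the config merger to replace the base file's `strategies: {}` placeholder wholesale rather than merging into it. Below is a minimal sketch of how such inheritance typically resolves; `merge_configs` is an illustrative helper, not part of Qlib's API.

```
# Illustrative sketch of "_base_" / "_delete_" semantics (hypothetical helper).
def merge_configs(base: dict, override: dict) -> dict:
    override = dict(override)  # avoid mutating the caller's dict
    if override.pop("_delete_", False):
        return override  # discard the base value entirely
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value
    return merged

base = {"concurrency": 5, "strategies": {}}
override = {"strategies": {"_delete_": True, "30min": {"class": "TWAPStrategy"}}}
print(merge_configs(base, override))
# {'concurrency': 5, 'strategies': {'30min': {'class': 'TWAPStrategy'}}}
```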
21 changes: 21 additions & 0 deletions examples/rl/experiment_config/backtest/twap.yml
@@ -0,0 +1,21 @@
order_file: ./data/backtest_orders.csv
start_time: "9:45"
end_time: "14:44"
qlib:
  provider_uri_1min: ./data/bin
  feature_root_dir: ./data/pickle
  feature_columns_today: [
    "$open", "$high", "$low", "$close", "$vwap", "$volume",
  ]
  feature_columns_yesterday: [
    "$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
  ]
exchange:
  limit_threshold: ['$close == 0', '$close == 0']
  deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
  volume_threshold:
    all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
    buy: ["current", "$close"]
    sell: ["current", "$close"]
strategies: {}  # Placeholder
concurrency: 5
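The `exchange` entries above are Qlib expression strings evaluated bar by bar; the two-element lists give separate buy/sell settings. For intuition, here is a rough pandas equivalent of the `deal_price` rule `If($close == 0, $vwap, $close)` (the dataframe and its columns are assumptions for illustration):

```
# Rough pandas analogue of the deal_price expression: when a bar's close is 0,
# fall back to VWAP. Illustrative only; Qlib evaluates the expression itself.
import pandas as pd

bar = pd.DataFrame({"$close": [10.0, 0.0, 9.8], "$vwap": [10.1, 9.9, 9.7]})
deal_price = bar["$close"].where(bar["$close"] != 0, bar["$vwap"])
print(deal_price.tolist())  # [10.0, 9.9, 9.8]
```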
59 changes: 59 additions & 0 deletions examples/rl/experiment_config/training/config.yml
@@ -0,0 +1,59 @@
simulator:
  time_per_step: 30
  vol_limit: null
env:
  concurrency: 1
  parallel_mode: dummy
action_interpreter:
  class: CategoricalActionInterpreter
  kwargs:
    values: 14
    max_step: 8
  module_path: qlib.rl.order_execution.interpreter
state_interpreter:
  class: FullHistoryStateInterpreter
  kwargs:
    data_dim: 6
    data_ticks: 240
    max_step: 8
    processed_data_provider:
      class: PickleProcessedDataProvider
      module_path: qlib.rl.data.pickle_styled
      kwargs:
        data_dir: ./data/pickle_dataframe/feature
  module_path: qlib.rl.order_execution.interpreter
reward:
  class: PAPenaltyReward
  kwargs:
    penalty: 100.0
  module_path: qlib.rl.order_execution.reward
data:
  source:
    order_dir: ./data/training_order_split
    data_dir: ./data/pickle_dataframe/backtest
    total_time: 240
    default_start_time: 0
    default_end_time: 240
    proc_data_dim: 6
  num_workers: 0
  queue_size: 20
network:
  class: Recurrent
  module_path: qlib.rl.order_execution.network
policy:
  class: PPO
  kwargs:
    lr: 0.0001
  module_path: qlib.rl.order_execution.policy
runtime:
  seed: 42
  use_cuda: false
trainer:
  max_epoch: 2
  repeat_per_collect: 5
  earlystop_patience: 2
  episode_per_collect: 20
  batch_size: 16
  val_every_n_epoch: 1
  checkpoint_path: ./checkpoints
  checkpoint_every_n_iters: 1
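Each `class` / `module_path` / `kwargs` block in this config is instantiated reflectively by the training entry point. A standalone sketch of that pattern, assuming the config above has been saved to disk (`instantiate` is an illustrative helper; `train_onpolicy.py` uses Qlib's own loader):

```
# Illustrative: turning a {class, module_path, kwargs} block into an object.
import importlib

import yaml

def instantiate(spec):
    module = importlib.import_module(spec["module_path"])
    cls = getattr(module, spec["class"])
    return cls(**spec.get("kwargs", {}))

with open("experiment_config/training/config.yml") as f:
    config = yaml.safe_load(f)

reward = instantiate(config["reward"])  # e.g. PAPenaltyReward(penalty=100.0)
```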
21 changes: 21 additions & 0 deletions examples/rl/scripts/collect_pickle_dataframe.py
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import pickle

import pandas as pd
from tqdm import tqdm

os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)

for tag in ("backtest", "feature"):
    # Each source pickle holds a dict of dataframes; flatten it into one frame.
    df = pickle.load(open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb"))
    df = pd.concat(list(df.values())).reset_index()
    df["date"] = df["datetime"].dt.date.astype("datetime64")
    instruments = sorted(set(df["instrument"]))

    # Re-split the flattened frame into one pickle per instrument.
    os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
    for instrument in tqdm(instruments):
        cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
        cur = cur.set_index(["instrument", "datetime", "date"])
        pickle.dump(cur, open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb"))
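A quick sanity check on the output (illustrative; assumes the example universe contains AAPL, the default stock in `gen_training_orders.py`):

```
# Load one per-instrument frame written by the script above.
import pickle

with open("data/pickle_dataframe/feature/AAPL.pkl", "rb") as f:
    frame = pickle.load(f)
print(frame.index.names)  # FrozenList(['instrument', 'datetime', 'date'])
```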
14 changes: 14 additions & 0 deletions examples/rl/scripts/data_pipeline.sh
@@ -0,0 +1,14 @@
#!/bin/bash
set -e

# Generate `bin` format data
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min

# Generate pickle format data
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
if [ -e stat/ ]; then
    # Clean up the stat/ directory left over from the previous step, if any
    rm -r stat/
fi
python scripts/collect_pickle_dataframe.py

# Sample orders
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py
41 changes: 41 additions & 0 deletions examples/rl/scripts/gen_backtest_orders.py
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import pickle

import numpy as np
import pandas as pd

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()

np.random.seed(args.seed)

path = os.path.join("data", "pickle", "backtesttest.pkl")  # TODO: rename file
df = pickle.load(open(path, "rb")).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")

instruments = sorted(set(df["instrument"]))
df_list = []
for instrument in instruments:
    print(instrument)

    cur_df = df[df["instrument"] == instrument]
    dates = sorted(set(str(d).split(" ")[0] for d in cur_df["date"]))

    # For each instrument, sample `num_order` distinct dates; amounts are
    # multiples of 100 and order_type is a 0/1 direction flag.
    n = args.num_order
    df_list.append(
        pd.DataFrame(
            {
                "date": sorted(np.random.choice(dates, size=n, replace=False)),
                "instrument": [instrument] * n,
                "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
                "order_type": np.random.randint(low=0, high=2, size=n),
            }
        ).set_index(["date", "instrument"]),
    )

total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")
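The resulting `data/backtest_orders.csv` is indexed by `(date, instrument)`, with `amount` as the order size and `order_type` as a 0/1 direction flag whose meaning is defined by Qlib's order model. Reading it back for inspection:

```
# Illustrative: inspect the generated order file.
import pandas as pd

orders = pd.read_csv("data/backtest_orders.csv", index_col=["date", "instrument"])
print(orders.head())
```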
43 changes: 43 additions & 0 deletions examples/rl/scripts/gen_pickle_data.py
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
from copy import deepcopy

import yaml

from qlib.contrib.data.highfreq_provider import HighFreqProvider

loader = yaml.FullLoader

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="config.yml")
    parser.add_argument("-d", "--dest", type=str, default=".")
    parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
    args = parser.parse_args()

    with open(args.config) as f:
        conf = yaml.load(f, Loader=loader)

    # Rebase all paths in the config onto the destination directory.
    for k, v in conf.items():
        if isinstance(v, dict) and "path" in v:
            v["path"] = os.path.join(args.dest, v["path"])
    provider = HighFreqProvider(**conf)

    # Generate dataframes
    if "feature_conf" in conf:
        feature = provider._gen_dataframe(deepcopy(provider.feature_conf))
    if "backtest_conf" in conf:
        backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf))

    # Redirect output paths from "<name>.pkl" to a "<name>/" directory
    # before generating the split datasets.
    provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
    provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"

    # Split by date
    if args.split in ("date", "both"):
        provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")

    # Split by stock
    if args.split in ("stock", "both"):
        provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
37 changes: 37 additions & 0 deletions examples/rl/scripts/gen_training_orders.py
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import pickle

import numpy as np
import pandas as pd

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()

np.random.seed(args.seed)

os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)

for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
    path = os.path.join("data", "pickle", f"backtest{group}.pkl")
    df = pickle.load(open(path, "rb")).reset_index()
    df["date"] = df["datetime"].dt.date.astype("datetime64")

    dates = sorted(set(str(d).split(" ")[0] for d in df["date"]))

    # Sample `n` distinct dates for this split; all orders target a single
    # stock, with amounts in multiples of 100 and order_type fixed to 0.
    data_df = pd.DataFrame(
        {
            "date": sorted(np.random.choice(dates, size=n, replace=False)),
            "instrument": [args.stock] * n,
            "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
            "order_type": [0] * n,
        }
    ).set_index(["date", "instrument"])

    os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
    pickle.dump(data_df, open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb"))
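The per-group pickles written here land in `data/training_order_split/{train,valid,test}/` (one file per stock, `AAPL.pkl` by default), which is the layout that `order_dir: ./data/training_order_split` in `experiment_config/training/config.yml` points at.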