Migrate amc4th training #1316

Merged · 3 commits · Oct 19, 2022
3 changes: 3 additions & 0 deletions .gitignore
@@ -24,6 +24,9 @@ qlib/VERSION.txt
qlib/data/_libs/expanding.cpp
qlib/data/_libs/rolling.cpp
examples/estimator/estimator_example/
examples/rl/data/
examples/rl/checkpoints/
examples/rl/outputs/

*.egg-info/

55 changes: 55 additions & 0 deletions examples/rl/README.md
@@ -0,0 +1,55 @@
This folder contains a simple example of how to run Qlib RL. It contains:

```
.
├── experiment_config
│ ├── backtest # Backtest config
│ └── training # Training config
├── README.md # Readme (the current file)
└── scripts # Scripts for data pre-processing
```

## Data preparation

Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:

```
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
mv qlib_rl_example_data data
```

The downloaded data will be placed under `./data`, with the raw CSV data in `data/csv`. To generate all the data this example needs, run:

```
bash scripts/data_pipeline.sh
```

After the pipeline finishes, the `data/` directory should look like this:

```
data
├── backtest_orders.csv
├── bin
├── csv
├── pickle
├── pickle_dataframe
└── training_order_split
```
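
For a quick sanity check of the generated orders, `data/backtest_orders.csv` is a plain CSV written by `scripts/gen_backtest_orders.py` (a `(date, instrument)` index plus `amount` and `order_type` columns; the exact rows depend on the sampling seed). A minimal sketch:

```
import pandas as pd

orders = pd.read_csv("data/backtest_orders.csv")
print(orders.head())
print(orders.groupby("instrument").size())  # number of sampled orders per instrument
```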

## Run training

Run:

```
python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml
```

After training, checkpoints will be stored under `checkpoints/`.
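
To verify that a checkpoint loads before backtesting, here is a minimal sketch, assuming `checkpoints/latest.pth` (the file the backtest config points at) is a standard PyTorch checkpoint:

```
import torch

# Assumption: the trainer writes regular PyTorch checkpoints.
state = torch.load("checkpoints/latest.pth", map_location="cpu")  # works without a GPU
print(type(state))
if isinstance(state, dict):
    print(list(state)[:10])  # peek at the top-level keys
```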

## Run backtest

```
python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py
```

The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
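
The exact layout of `outputs/` depends on the backtest script, so a simple way to explore it is to list whatever was written (a hedged sketch; adjust the preview step if the summary uses a format other than CSV):

```
from pathlib import Path
import pandas as pd

for p in sorted(Path("outputs").rglob("*")):
    print(p)
    if p.suffix == ".csv":
        print(pd.read_csv(p).head())  # preview any CSV summary
```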
53 changes: 53 additions & 0 deletions examples/rl/experiment_config/backtest/config.py
@@ -0,0 +1,53 @@
_base_ = ["./twap.yml"]

strategies = {
    "_delete_": True,
    "30min": {
        "class": "TWAPStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {},
    },
    "1day": {
        "class": "SAOEIntStrategy",
        "module_path": "qlib.rl.order_execution.strategy",
        "kwargs": {
            "state_interpreter": {
                "class": "FullHistoryStateInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "max_step": 8,
                    "data_ticks": 240,
                    "data_dim": 6,
                    "processed_data_provider": {
                        "class": "PickleProcessedDataProvider",
                        "module_path": "qlib.rl.data.pickle_styled",
                        "kwargs": {
                            "data_dir": "./data/pickle_dataframe/feature",
                        },
                    },
                },
            },
            "action_interpreter": {
                "class": "CategoricalActionInterpreter",
                "module_path": "qlib.rl.order_execution.interpreter",
                "kwargs": {
                    "values": 14,
                    "max_step": 8,
                },
            },
            "network": {
                "class": "Recurrent",
                "module_path": "qlib.rl.order_execution.network",
                "kwargs": {},
            },
            "policy": {
                "class": "PPO",
                "module_path": "qlib.rl.order_execution.policy",
                "kwargs": {
                    "lr": 1.0e-4,
                    "weight_file": "./checkpoints/latest.pth",
                },
            },
        },
    },
}
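
A note on the two special keys above: `_base_ = ["./twap.yml"]` layers this file on top of the TWAP config, and `"_delete_": True` tells the loader to replace the inherited `strategies` placeholder wholesale rather than merge into it. This follows a common layered-config convention; the sketch below only illustrates the merge semantics and is not Qlib's actual loader:

```
def merge(base: dict, override: dict) -> dict:
    # `_delete_: True` in the override replaces the inherited subtree entirely.
    if override.pop("_delete_", False):
        return override
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out

# The empty `strategies: {}` placeholder from twap.yml is replaced wholesale:
print(merge({"strategies": {}, "concurrency": 5},
            {"strategies": {"_delete_": True, "30min": {"class": "TWAPStrategy"}}}))
# -> {'strategies': {'30min': {'class': 'TWAPStrategy'}}, 'concurrency': 5}
```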
21 changes: 21 additions & 0 deletions examples/rl/experiment_config/backtest/twap.yml
@@ -0,0 +1,21 @@
order_file: ./data/backtest_orders.csv
start_time: "9:45"
end_time: "14:44"
qlib:
  provider_uri_1min: ./data/bin
  feature_root_dir: ./data/pickle
  feature_columns_today: [
    "$open", "$high", "$low", "$close", "$vwap", "$volume",
  ]
  feature_columns_yesterday: [
    "$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
  ]
exchange:
  limit_threshold: ['$close == 0', '$close == 0']
  deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
  volume_threshold:
    all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
    buy: ["current", "$close"]
    sell: ["current", "$close"]
strategies: {}  # Placeholder
concurrency: 5
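
Two of the exchange expressions above deserve a note: `limit_threshold: '$close == 0'` marks zero-close bars (e.g. suspended minutes) as untradable, and `deal_price: If($close == 0, $vwap, $close)` falls back to VWAP on exactly those bars. In pandas terms, the deal-price rule is roughly the following (an illustrative sketch, not Qlib's expression engine):

```
import pandas as pd

close = pd.Series([10.0, 0.0, 10.2])   # 0.0 marks a bar with no trades
vwap = pd.Series([10.1, 10.05, 10.15])

# If($close == 0, $vwap, $close): keep close, use vwap where close is zero
deal_price = close.where(close != 0, vwap)
print(deal_price.tolist())  # [10.0, 10.05, 10.2]
```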
59 changes: 59 additions & 0 deletions examples/rl/experiment_config/training/config.yml
@@ -0,0 +1,59 @@
simulator:
  time_per_step: 30
  vol_limit: null
env:
  concurrency: 1
  parallel_mode: dummy
action_interpreter:
  class: CategoricalActionInterpreter
  kwargs:
    values: 14
    max_step: 8
  module_path: qlib.rl.order_execution.interpreter
state_interpreter:
  class: FullHistoryStateInterpreter
  kwargs:
    data_dim: 6
    data_ticks: 240
    max_step: 8
    processed_data_provider:
      class: PickleProcessedDataProvider
      module_path: qlib.rl.data.pickle_styled
      kwargs:
        data_dir: ./data/pickle_dataframe/feature
  module_path: qlib.rl.order_execution.interpreter
reward:
  class: PAPenaltyReward
  kwargs:
    penalty: 100.0
  module_path: qlib.rl.order_execution.reward
data:
  source:
    order_dir: ./data/training_order_split
    data_dir: ./data/pickle_dataframe/backtest
    total_time: 240
    default_start_time: 0
    default_end_time: 240
    proc_data_dim: 6
  num_workers: 0
  queue_size: 20
network:
  class: Recurrent
  module_path: qlib.rl.order_execution.network
policy:
  class: PPO
  kwargs:
    lr: 0.0001
  module_path: qlib.rl.order_execution.policy
runtime:
  seed: 42
  use_cuda: false
trainer:
  max_epoch: 2
  repeat_per_collect: 5
  earlystop_patience: 2
  episode_per_collect: 20
  batch_size: 16
  val_every_n_epoch: 1
  checkpoint_path: ./checkpoints
  checkpoint_every_n_iters: 1
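
Throughout both configs, components are declared as `class` / `module_path` / `kwargs` triples. Qlib has a helper of this shape (`qlib.utils.init_instance_by_config`); the sketch below just illustrates the pattern and is not the actual implementation:

```
import importlib

def instantiate(spec: dict):
    # Build an object from a {class, module_path, kwargs} spec.
    module = importlib.import_module(spec["module_path"])
    cls = getattr(module, spec["class"])
    return cls(**spec.get("kwargs", {}))

# e.g. the reward section above resolves to roughly:
# qlib.rl.order_execution.reward.PAPenaltyReward(penalty=100.0)
```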
21 changes: 21 additions & 0 deletions examples/rl/scripts/collect_pickle_dataframe.py
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import pickle
import pandas as pd
from tqdm import tqdm

os.makedirs(os.path.join("data", "pickle_dataframe"), exist_ok=True)

for tag in ("backtest", "feature"):
    # Each source pickle holds a dict mapping instrument -> dataframe; merge them.
    with open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb") as f:
        df = pickle.load(f)
    df = pd.concat(list(df.values())).reset_index()
    df["date"] = df["datetime"].dt.date.astype("datetime64")  # day-level timestamps
    instruments = sorted(set(df["instrument"]))

    # Re-split into one pickle per instrument.
    os.makedirs(os.path.join("data", "pickle_dataframe", tag), exist_ok=True)
    for instrument in tqdm(instruments):
        cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
        cur = cur.set_index(["instrument", "datetime", "date"])
        with open(os.path.join("data", "pickle_dataframe", tag, f"{instrument}.pkl"), "wb") as f:
            pickle.dump(cur, f)
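
The result is one pickled dataframe per instrument, indexed by `(instrument, datetime, date)`. A quick way to inspect one (AAPL is simply the default ticker used by `gen_training_orders.py` below):

```
import pickle

with open("data/pickle_dataframe/feature/AAPL.pkl", "rb") as f:
    feature_df = pickle.load(f)
print(feature_df.index.names)  # ['instrument', 'datetime', 'date']
print(feature_df.head())
```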
14 changes: 14 additions & 0 deletions examples/rl/scripts/data_pipeline.sh
@@ -0,0 +1,14 @@
#!/bin/bash
# Run from the examples/rl directory: bash scripts/data_pipeline.sh
set -e

# Generate `bin` format data
python ../../scripts/dump_bin.py dump_all --csv_path ./data/csv --qlib_dir ./data/bin --include_fields open,close,high,low,vwap,volume --symbol_field_name symbol --date_field_name date --freq 1min

# Generate pickle format data
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
if [ -e stat/ ]; then
    rm -r stat/
fi
python scripts/collect_pickle_dataframe.py

# Sample orders
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py
41 changes: 41 additions & 0 deletions examples/rl/scripts/gen_backtest_orders.py
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import pandas as pd
import numpy as np
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--num_order", type=int, default=10)
args = parser.parse_args()

np.random.seed(args.seed)

path = os.path.join("data", "pickle", "backtesttest.pkl")  # TODO: rename file
with open(path, "rb") as f:
    df = pickle.load(f).reset_index()
df["date"] = df["datetime"].dt.date.astype("datetime64")

instruments = sorted(set(df["instrument"]))
df_list = []
for instrument in instruments:
    print(instrument)

    cur_df = df[df["instrument"] == instrument]

    # Dates on which this instrument has data, as "YYYY-MM-DD" strings
    dates = sorted(set([str(d).split(" ")[0] for d in cur_df["date"]]))

    # Sample `num_order` order dates; amounts are multiples of 100 shares,
    # order_type is drawn from {0, 1}.
    n = args.num_order
    df_list.append(
        pd.DataFrame({
            "date": sorted(np.random.choice(dates, size=n, replace=False)),
            "instrument": [instrument] * n,
            "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
            "order_type": np.random.randint(low=0, high=2, size=n),
        }).set_index(["date", "instrument"]),
    )

total_df = pd.concat(df_list)
total_df.to_csv("data/backtest_orders.csv")
43 changes: 43 additions & 0 deletions examples/rl/scripts/gen_pickle_data.py
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import yaml
import argparse
import os
from copy import deepcopy

from qlib.contrib.data.highfreq_provider import HighFreqProvider
[Collaborator comment] Please replace print_log in highfreq_provider with qlib.logging.

loader = yaml.FullLoader

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="config.yml")
    parser.add_argument("-d", "--dest", type=str, default=".")
    parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
    args = parser.parse_args()

    conf = yaml.load(open(args.config), Loader=loader)

    # Prefix every configured path with the destination directory
    for k, v in conf.items():
        if isinstance(v, dict) and "path" in v:
            v["path"] = os.path.join(args.dest, v["path"])
    provider = HighFreqProvider(**conf)

    # Gen dataframe
    if "feature_conf" in conf:
        feature = provider._gen_dataframe(deepcopy(provider.feature_conf))
    if "backtest_conf" in conf:
        backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf))

    provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
    provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"

    # Split by date
    if args.split == "date" or args.split == "both":
        provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")

    # Split by stock
    if args.split == "stock" or args.split == "both":
        provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
37 changes: 37 additions & 0 deletions examples/rl/scripts/gen_training_orders.py
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os
import pandas as pd
import numpy as np
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=20220926)
parser.add_argument("--stock", type=str, default="AAPL")
parser.add_argument("--train_size", type=int, default=10)
parser.add_argument("--valid_size", type=int, default=2)
parser.add_argument("--test_size", type=int, default=2)
args = parser.parse_args()

np.random.seed(args.seed)

os.makedirs(os.path.join("data", "training_order_split"), exist_ok=True)

for group, n in zip(("train", "valid", "test"), (args.train_size, args.valid_size, args.test_size)):
    path = os.path.join("data", "pickle", f"backtest{group}.pkl")
    with open(path, "rb") as f:
        df = pickle.load(f).reset_index()
    df["date"] = df["datetime"].dt.date.astype("datetime64")

    dates = sorted(set([str(d).split(" ")[0] for d in df["date"]]))

    # Sample `n` order dates for the chosen stock; amounts are multiples of 100 shares.
    data_df = pd.DataFrame({
        "date": sorted(np.random.choice(dates, size=n, replace=False)),
        "instrument": [args.stock] * n,
        "amount": np.random.randint(low=3, high=11, size=n) * 100.0,
        "order_type": [0] * n,
    }).set_index(["date", "instrument"])

    os.makedirs(os.path.join("data", "training_order_split", group), exist_ok=True)
    with open(os.path.join("data", "training_order_split", group, f"{args.stock}.pkl"), "wb") as f:
        pickle.dump(data_df, f)