You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
427 lines
9.6 KiB
427 lines
9.6 KiB
# =========================================
# Industrial-grade AI stock-selection system (final integrated version)
# Index timing + dynamic position sizing + backtesting + parameter tuning
# =========================================
|
|
|
|
import sqlite3
|
|
import akshare as ak
|
|
import pandas as pd
|
|
|
|
DB_PATH = "stock_data.db"
|
|
|
|
# =========================================
|
|
# 全局参数配置
|
|
# =========================================
|
|
CONFIG = {
|
|
|
|
# ===== 数据 =====
|
|
"START_DATE": "20230101",
|
|
"END_DATE": "20240401",
|
|
|
|
# ===== 均线 =====
|
|
"MA_SHORT": 5,
|
|
"MA_MID": 20,
|
|
"MA_LONG": 60,
|
|
|
|
# ===== 特征 =====
|
|
"RETURN_WINDOW": 3,
|
|
"VOL_WINDOW": 5,
|
|
"VOLATILITY_WINDOW": 10,
|
|
|
|
# ===== 权重 =====
|
|
"W_MOMENTUM": 0.4,
|
|
"W_TREND": 0.3,
|
|
"W_VOLUME": 0.2,
|
|
"W_VOLATILITY": -0.1,
|
|
|
|
# ===== 策略 =====
|
|
"RETURN_THRESH": 0.02,
|
|
"VOL_THRESH": 1.2,
|
|
|
|
"HOLD_DAYS": 3,
|
|
"STOP_LOSS": -0.05,
|
|
"TAKE_PROFIT": 0.1,
|
|
|
|
# ===== 成本 =====
|
|
"FEE": 0.001,
|
|
"SLIPPAGE": 0.001,
|
|
|
|
# ===== Walk Forward =====
|
|
"TRAIN_WINDOW": 60,
|
|
"TEST_WINDOW": 20,
|
|
|
|
# ===== 指数 =====
|
|
"INDEX_LIST": {
|
|
"hs300": "000300",
|
|
"zz500": "000905",
|
|
"cyb": "399006",
|
|
}
|
|
}
|
|
|
|
# =========================
# Database
# =========================
|
|
def init_db():
    """Create the ``stock_daily`` table in the SQLite file at ``DB_PATH``.

    The statement uses ``IF NOT EXISTS``, so calling this repeatedly is
    harmless. Rows are keyed by ``(code, date)``.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS stock_daily (
                code TEXT,
                date TEXT,
                open REAL,
                close REAL,
                high REAL,
                low REAL,
                volume REAL,
                PRIMARY KEY (code, date)
            )
            """
        )
        conn.commit()
    finally:
        # Release the handle even when the DDL statement fails.
        conn.close()
|
|
|
|
|
|
def load_from_db(code):
    """Return all cached ``stock_daily`` rows for ``code``, sorted by date.

    Args:
        code: Value matched against the ``code`` column.

    Returns:
        A DataFrame with every column of ``stock_daily`` (possibly empty),
        ordered by the ``date`` column.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # Bind ``code`` as a parameter instead of formatting it into the SQL
        # text, so quotes in the value can neither break nor alter the query.
        df = pd.read_sql_query(
            "SELECT * FROM stock_daily WHERE code = ?",
            conn,
            params=(code,),
        )
    finally:
        conn.close()
    return df.sort_values("date")
|
|
|
|
|
|
def save_to_db(df, code):
    """Append the rows of ``df`` to ``stock_daily`` tagged with ``code``.

    Args:
        df: DataFrame whose columns match the ``stock_daily`` table (apart
            from ``code``, which is added here).
        code: Value written to the ``code`` column of every row.

    The caller's frame is left untouched: the ``code`` column is added to a
    copy rather than written into ``df`` itself.

    NOTE(review): rows are appended as-is. If the table created by
    ``init_db`` already holds a row with the same ``(code, date)``, the
    primary key should make this insert fail; confirm whether callers can
    hit that case.
    """
    rows = df.assign(code=code)
    conn = sqlite3.connect(DB_PATH)
    try:
        rows.to_sql("stock_daily", conn, if_exists="append", index=False)
    finally:
        # Close the connection even if the insert raises.
        conn.close()
|
|
|
|
# =========================
# Data acquisition
# =========================
|
|
def fetch_data(code):
    """Download daily bars for ``code`` and normalise the column names.

    ``ak.stock_zh_a_hist`` is tried first with the configured date range;
    if it raises, ``ak.stock_zh_a_daily`` is used as a fallback (an error
    from the fallback propagates to the caller).

    Args:
        code: Stock symbol passed to akshare.

    Returns:
        A DataFrame with exactly the columns ``date, open, close, high, low,
        volume``, or ``None`` when nothing was returned or any of those
        columns is missing after renaming.
    """
    try:
        df = ak.stock_zh_a_hist(
            symbol=code,
            period="daily",
            start_date=CONFIG["START_DATE"],
            end_date=CONFIG["END_DATE"],
            adjust="qfq",
        )
    except Exception:
        # ``Exception`` rather than a bare ``except`` so Ctrl-C / SystemExit
        # are not turned into a silent fallback request.
        df = ak.stock_zh_a_daily(symbol=code)

    if df is None or df.empty:
        return None

    df = df.rename(columns={
        "日期": "date",
        "开盘": "open",
        "收盘": "close",
        "最高": "high",
        "最低": "low",
        "成交量": "volume",
    })

    required = ["date", "open", "close", "high", "low", "volume"]
    # Validate every column that is selected below, not only ``date``;
    # otherwise a partially-shaped response raises KeyError.
    if any(col not in df.columns for col in required):
        return None

    return df[required]
|
|
|
|
|
|
def get_data(code):
    """Return daily bars for ``code``, preferring the local SQLite cache.

    The cached rows are used when more than 30 of them exist. Otherwise the
    data is fetched with ``fetch_data`` and, if that succeeded, written to
    the cache with ``save_to_db``.

    Args:
        code: Stock symbol.

    Returns:
        A DataFrame, or ``None`` when the cache is too small and
        ``fetch_data`` returned ``None``.
    """
    try:
        cached = load_from_db(code)
        if len(cached) > 30:
            return cached
    except Exception:
        # The cache is best-effort: any read problem just means "refetch".
        # Catching ``Exception`` (not a bare ``except``) keeps
        # KeyboardInterrupt and SystemExit propagating.
        pass

    df = fetch_data(code)
    if df is not None:
        save_to_db(df, code)
    return df
|
|
|
|
# =========================
# Feature engineering
# =========================
|
|
def add_features(df):
    """Add the indicator columns used for scoring and stock selection.

    Columns written into ``df`` (in place): ``return_n``, ``ma_short``,
    ``ma_mid``, ``ma_long``, ``vol_ratio``, ``volatility`` and ``trend``.
    Window lengths come from ``CONFIG``.

    Args:
        df: Frame with at least ``close`` and ``volume`` columns; rows are
            presumably in chronological order (confirm against callers).

    Returns:
        A new frame equal to ``df`` with every row containing a NaN dropped.
    """
    close = df['close']
    volume = df['volume']

    # Return over RETURN_WINDOW rows.
    df['return_n'] = close.pct_change(CONFIG["RETURN_WINDOW"])

    # Three rolling means of the close price.
    ma_windows = (
        ('ma_short', "MA_SHORT"),
        ('ma_mid', "MA_MID"),
        ('ma_long', "MA_LONG"),
    )
    for column, config_key in ma_windows:
        df[column] = close.rolling(CONFIG[config_key]).mean()

    # Current volume relative to its rolling mean.
    avg_volume = volume.rolling(CONFIG["VOL_WINDOW"]).mean()
    df['vol_ratio'] = volume / avg_volume

    # Rolling standard deviation of one-row returns.
    one_row_return = close.pct_change()
    df['volatility'] = one_row_return.rolling(CONFIG["VOLATILITY_WINDOW"]).std()

    # True where the mid moving average is above the long one.
    df['trend'] = df['ma_mid'] > df['ma_long']

    return df.dropna()
|
|
|
|
|
|
def score(df):
    """Write a weighted ``score`` column into ``df`` and return ``df``.

    The score is a linear combination, weighted by the ``W_*`` entries of
    ``CONFIG``, of: ``return_n``, the relative gap between ``ma_short`` and
    ``ma_mid``, ``vol_ratio`` and ``volatility``.

    Args:
        df: Frame that already carries the columns named above.

    Returns:
        The same frame object, with ``score`` added in place.
    """
    # Relative distance of the short moving average above the mid one.
    ma_gap = (df['ma_short'] - df['ma_mid']) / df['ma_mid']

    df['score'] = (
        CONFIG["W_MOMENTUM"] * df['return_n']
        + CONFIG["W_TREND"] * ma_gap
        + CONFIG["W_VOLUME"] * df['vol_ratio']
        + CONFIG["W_VOLATILITY"] * df['volatility']
    )
    return df
|
|
|
|
# =========================
# Index module
# =========================
|
|
def get_index_data(code):
    """Download the daily close series of an index over the configured range.

    Args:
        code: Index symbol passed to ``ak.index_zh_a_hist``.

    Returns:
        A DataFrame with columns ``date`` and ``close``, or ``None`` when the
        request raises, returns nothing, or lacks either column after
        renaming.
    """
    try:
        df = ak.index_zh_a_hist(
            symbol=code,
            period="daily",
            start_date=CONFIG["START_DATE"],
            end_date=CONFIG["END_DATE"],
        )
    except Exception:
        # A failed download is reported as "no data"; ``Exception`` instead
        # of a bare ``except`` lets KeyboardInterrupt/SystemExit through.
        return None

    if df is None or df.empty:
        return None

    df = df.rename(columns={"日期": "date", "收盘": "close"})

    # Guard the selection below so an unexpected response shape yields
    # ``None`` instead of a KeyError.
    if not {"date", "close"}.issubset(df.columns):
        return None

    return df[['date', 'close']]
|
|
|
|
|
|
def index_trend(df):
    """Tell whether the latest mid moving average is above the long one.

    Args:
        df: Frame with a ``close`` column; the last row is treated as the
            most recent observation.

    Returns:
        The result of comparing the final ``MA_MID``-row rolling mean with
        the final ``MA_LONG``-row rolling mean of ``close``.
    """
    close = df['close']
    latest_mid = close.rolling(CONFIG["MA_MID"]).mean().iloc[-1]
    latest_long = close.rolling(CONFIG["MA_LONG"]).mean().iloc[-1]
    return latest_mid > latest_long
|
|
|
|
|
|
def market_score(index_data):
    """Combine the trend of three indices into one integer market score.

    Scoring:
        * ``hs300``: +2 when trending up, -2 otherwise.
        * ``zz500``: +2 when trending up.
        * ``cyb``:   +1 when trending up.

    An index that is absent from ``index_data`` (or mapped to ``None``)
    contributes 0 instead of raising ``KeyError``; the caller only stores
    indices whose download succeeded, so entries can be missing.

    Args:
        index_data: Mapping of index short name to its price DataFrame.

    Returns:
        The summed score (between -2 and 5).
    """
    total = 0

    hs300 = index_data.get("hs300")
    if hs300 is not None:
        total += 2 if index_trend(hs300) else -2

    zz500 = index_data.get("zz500")
    if zz500 is not None and index_trend(zz500):
        total += 2

    cyb = index_data.get("cyb")
    if cyb is not None and index_trend(cyb):
        total += 1

    return total
|
|
|
|
|
|
def market_strategy(score):
    """Map a market score to the fraction of cash to deploy.

    Args:
        score: Market score as produced by ``market_score``.

    Returns:
        0.7 for scores of at least 4, 0.4 for at least 2, 0.2 for at least
        0, and 0 for anything lower.
    """
    # (minimum score, position ratio), checked from the highest tier down.
    tiers = (
        (4, 0.7),
        (2, 0.4),
        (0, 0.2),
    )
    for minimum, ratio in tiers:
        if score >= minimum:
            return ratio
    return 0
|
|
|
|
# =========================
# Trade functions
# =========================
|
|
def trade_return(prices):
    """Compute the net return of one trade over a price path.

    The trade is entered at ``prices[0]`` marked up by the slippage rate and
    each later price (marked down by the same rate) is treated as a possible
    exit. The walk stops at the first price whose return reaches the
    stop-loss or take-profit level, and the return is then set to that
    level; otherwise the return at the last price is used.

    Args:
        prices: Sequence of prices starting with the entry price.

    Returns:
        0 for an empty sequence; otherwise the trade return minus twice the
        configured fee.
    """
    # ``len(...) == 0`` rather than truthiness: callers pass array-like data.
    if len(prices) == 0:
        return 0

    slippage = CONFIG["SLIPPAGE"]
    stop_loss = CONFIG["STOP_LOSS"]
    take_profit = CONFIG["TAKE_PROFIT"]

    entry = prices[0] * (1 + slippage)
    outcome = 0

    for price in prices:
        current = (price * (1 - slippage) - entry) / entry

        if current <= stop_loss:
            # Stopped out: the return is booked at the threshold itself.
            outcome = stop_loss
            break
        if current >= take_profit:
            # Target hit: the return is booked at the threshold itself.
            outcome = take_profit
            break

        outcome = current

    return outcome - CONFIG["FEE"] * 2
|
|
|
|
# =========================
# Backtest
# =========================
|
|
def backtest_real(df_all, params, position_ratio, init_cash=100000):
    """Simulate the selection strategy day by day and return the cash curve.

    For every date except the last ``params['hold_days']`` ones, the most
    recent row of each code (up to and including that date) is filtered by
    the return, volume-ratio and trend conditions; the three highest-scoring
    rows are bought with equal shares of ``cash * position_ratio`` and their
    profit from ``trade_return`` is added to cash immediately.

    Args:
        df_all: Frame with ``date``, ``code``, ``close``, ``return_n``,
            ``vol_ratio``, ``trend`` and ``score`` columns.
        params: Dict with ``return_thresh``, ``vol_thresh`` and ``hold_days``.
        position_ratio: Fraction of cash committed on each trading date.
        init_cash: Starting cash.

    Returns:
        List with the cash balance after each simulated date.
    """
    hold_days = params['hold_days']
    trading_days = sorted(df_all['date'].unique())
    # Clamp at zero: a negative count must mean "no days", not a
    # from-the-end slice.
    usable = max(len(trading_days) - hold_days, 0)

    balance = init_cash
    curve = []

    for today in trading_days[:usable]:
        # Latest known row per code as of ``today``.
        history = df_all[df_all['date'] <= today]
        latest = history.groupby('code').tail(1)

        signal = (
            (latest['return_n'] > params['return_thresh'])
            & (latest['vol_ratio'] > params['vol_thresh'])
            & (latest['trend'])
        )
        ranked = latest[signal].sort_values('score', ascending=False)
        picks = ranked.head(3)

        if position_ratio == 0 or len(picks) == 0:
            curve.append(balance)
            continue

        per_stock = balance * position_ratio / len(picks)

        day_profit = 0
        for code in picks['code']:
            # Closes from ``today`` onwards, limited to the holding period.
            upcoming = df_all[
                (df_all['code'] == code) & (df_all['date'] >= today)
            ]
            path = upcoming.sort_values('date')['close'].values[:hold_days]
            day_profit += per_stock * trade_return(path)

        balance += day_profit
        curve.append(balance)

    return curve
|
|
|
|
# =========================
# Parameter tuning
# =========================
|
|
def simple_grid(df):
    """Pick the best (return threshold, volume threshold) pair by grid search.

    Every combination of ``return_thresh`` in ``[0.01, 0.02]`` and
    ``vol_thresh`` in ``[1.1, 1.2]`` is backtested on ``df`` with a fixed
    position ratio of 0.4. A combination is ranked by total return minus
    maximum drawdown; runs with fewer than 10 equity points are ignored.

    Args:
        df: Training frame accepted by ``backtest_real``.

    Returns:
        The parameter dict (``return_thresh``, ``vol_thresh``,
        ``hold_days``) of the best run. When no run qualifies, the defaults
        ``CONFIG["RETURN_THRESH"]`` / ``CONFIG["VOL_THRESH"]`` are returned
        instead of ``None``, so the result can always be passed on to
        ``backtest_real``.
    """
    hold_days = CONFIG["HOLD_DAYS"]
    best_objective = float("-inf")
    best_params = None

    for return_thresh in [0.01, 0.02]:
        for vol_thresh in [1.1, 1.2]:
            params = {
                "return_thresh": return_thresh,
                "vol_thresh": vol_thresh,
                "hold_days": hold_days,
            }

            equity = backtest_real(df, params, 0.4)

            # Too few points to judge this combination.
            if len(equity) < 10:
                continue

            total_return = (equity[-1] - equity[0]) / equity[0]
            drawdown = calc_max_drawdown(equity)

            # Named ``objective`` so the module-level ``score`` function is
            # not shadowed.
            objective = total_return - drawdown

            if objective > best_objective:
                best_objective = objective
                best_params = params

    if best_params is None:
        # No candidate produced enough data: fall back to configured values.
        best_params = {
            "return_thresh": CONFIG["RETURN_THRESH"],
            "vol_thresh": CONFIG["VOL_THRESH"],
            "hold_days": hold_days,
        }

    return best_params
|
|
|
|
# =========================
# Evaluation
# =========================
|
|
def calc_max_drawdown(equity):
    """Return the largest peak-to-trough decline of an equity curve.

    The decline at each point is ``(running_peak - value) / running_peak``.

    Args:
        equity: Sequence of account values in chronological order.

    Returns:
        The maximum decline as a fraction (0 when the curve never falls
        below its running peak). An empty curve yields 0 instead of raising
        ``IndexError``, and points whose running peak is not positive are
        skipped to avoid a division by zero.
    """
    if len(equity) == 0:
        return 0

    peak = equity[0]
    max_dd = 0

    for value in equity:
        if value > peak:
            peak = value
        if peak > 0:
            max_dd = max(max_dd, (peak - value) / peak)

    return max_dd
|
|
|
|
# =========================
# Walk forward (with index timing)
# =========================
|
|
def walk_forward(df_all, index_data):
    """Run a rolling train/test backtest and stitch the test equity curves.

    The distinct dates of ``df_all`` are cut into windows of
    ``TRAIN_WINDOW`` training dates followed by ``TEST_WINDOW`` testing
    dates, advancing by ``TEST_WINDOW`` each time. Parameters are tuned on
    the training slice with ``simple_grid`` and applied to the testing slice
    with the position ratio derived from the index score.

    Args:
        df_all: Feature frame for all codes (see ``backtest_real``).
        index_data: Mapping of index name to price frame for
            ``market_score``.

    Returns:
        One continuous equity list; each later window is shifted so that it
        starts at the last value of the previous one. Empty when no window
        produced a result.

    NOTE(review): ``market_score`` receives the complete ``index_data`` in
    every window, not the part up to the window's dates, so the position
    ratio is identical for all windows and may use future index prices.
    Confirm whether that is intended.
    """
    dates = sorted(df_all['date'].unique())

    window = CONFIG["TRAIN_WINDOW"]
    step = CONFIG["TEST_WINDOW"]

    final_equity = []

    # ``+ 1`` keeps the last start position whose train + test span still
    # fits entirely inside ``dates``.
    for start in range(0, len(dates) - window - step + 1, step):
        train_dates = dates[start:start + window]
        test_dates = dates[start + window:start + window + step]

        train_df = df_all[df_all['date'].isin(train_dates)]
        test_df = df_all[df_all['date'].isin(test_dates)]

        best_params = simple_grid(train_df)
        if best_params is None:
            # Nothing could be tuned on this slice; skip the window.
            continue

        # ===== The index trend decides the position size =====
        market = market_score(index_data)
        position_ratio = market_strategy(market)

        if position_ratio == 0:
            continue

        equity = backtest_real(test_df, best_params, position_ratio)
        if not equity:
            # The test slice was too short to yield any equity point.
            continue

        if not final_equity:
            # Copy so later extensions never touch the backtest's own list.
            final_equity = list(equity)
        else:
            base = final_equity[-1]
            first = equity[0]
            final_equity.extend(base + (x - first) for x in equity)

    return final_equity
|
|
|
|
# =========================
# Main program
# =========================
|
|
def main():
    """Run the full pipeline: load data, build features, backtest, report.

    Steps: initialise the cache table, load each stock in the hard-coded
    list, add features and scores, fetch the configured indices once, run
    the walk-forward backtest and print final cash, total return and
    maximum drawdown.

    The function returns early with a message when no stock data could be
    loaded or when the backtest produced no equity points, instead of
    failing inside ``pd.concat`` or on ``equity[-1]``.
    """
    init_db()

    codes = ["000001", "000002", "600519", "600036", "601318"]

    all_data = []
    for code in codes:
        df = get_data(code)
        if df is None:
            continue

        df['code'] = code
        df = add_features(df)
        df = score(df)
        all_data.append(df)

    if not all_data:
        # pd.concat raises on an empty list.
        print("未获取到任何股票数据,无法回测")
        return

    df_all = pd.concat(all_data)

    # ===== Fetch every index once =====
    index_data = {}
    for name, code in CONFIG["INDEX_LIST"].items():
        df = get_index_data(code)
        if df is not None:
            index_data[name] = df

    equity = walk_forward(df_all, index_data)

    print("\n===== 实盘级回测 =====")

    if not equity:
        # No window produced a result, so there is nothing to report.
        print("无回测结果(数据不足或指数择时空仓)")
        return

    print("最终资金:", round(equity[-1], 2))

    ret = (equity[-1] - equity[0]) / equity[0]
    dd = calc_max_drawdown(equity)

    print("总收益:", round(ret, 2))
    print("最大回撤:", round(dd, 2))


if __name__ == "__main__":
    main()
|