3 changed files with 453 additions and 0 deletions
@ -0,0 +1,427 @@ |
|||
# ========================================= |
|||
# 工业级AI选股系统(最终整合版) |
|||
# 指数择时 + 动态仓位 + 回测 + 调参 |
|||
# ========================================= |
|||
|
|||
import sqlite3 |
|||
import akshare as ak |
|||
import pandas as pd |
|||
|
|||
# SQLite file opened by init_db / load_from_db / save_to_db.
DB_PATH = "stock_data.db"

# =========================================
# Global parameter configuration
# =========================================
CONFIG = {

    # ===== Data =====
    # Date range passed to the akshare history calls (YYYYMMDD strings).
    "START_DATE": "20230101",
    "END_DATE": "20240401",

    # ===== Moving averages =====
    # Rolling-window lengths (in rows) for the close-price averages.
    "MA_SHORT": 5,
    "MA_MID": 20,
    "MA_LONG": 60,

    # ===== Features =====
    "RETURN_WINDOW": 3,        # periods for the pct_change momentum feature
    "VOL_WINDOW": 5,           # rolling mean window for the volume ratio
    "VOLATILITY_WINDOW": 10,   # rolling std window over one-period returns

    # ===== Score weights =====
    "W_MOMENTUM": 0.4,
    "W_TREND": 0.3,
    "W_VOLUME": 0.2,
    "W_VOLATILITY": -0.1,

    # ===== Strategy =====
    "RETURN_THRESH": 0.02,
    "VOL_THRESH": 1.2,

    "HOLD_DAYS": 3,
    "STOP_LOSS": -0.05,
    "TAKE_PROFIT": 0.1,

    # ===== Costs =====
    "FEE": 0.001,        # charged twice per trade in trade_return
    "SLIPPAGE": 0.001,   # applied to both entry and exit prices

    # ===== Walk forward =====
    # Window lengths counted in distinct dates.
    "TRAIN_WINDOW": 60,
    "TEST_WINDOW": 20,

    # ===== Indices =====
    # Short name -> index symbol passed to get_index_data.
    "INDEX_LIST": {
        "hs300": "000300",
        "zz500": "000905",
        "cyb": "399006",
    }
}
|||
|
|||
# ========================= |
|||
# 数据库 |
|||
# ========================= |
|||
def init_db():
    """Create the ``stock_daily`` table in the SQLite file at ``DB_PATH``.

    The statement uses ``CREATE TABLE IF NOT EXISTS``, so calling this
    repeatedly is harmless. The connection is always closed, even when the
    statement fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # ``with conn`` commits on success and rolls back on error; it does
        # not close the connection, hence the surrounding try/finally.
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS stock_daily (
                    code TEXT,
                    date TEXT,
                    open REAL,
                    close REAL,
                    high REAL,
                    low REAL,
                    volume REAL,
                    PRIMARY KEY (code, date)
                )
            """)
    finally:
        conn.close()
|||
|
|||
|
|||
def load_from_db(code):
    """Return all cached ``stock_daily`` rows for ``code``, sorted by date.

    Args:
        code: Stock code to look up.

    Returns:
        A DataFrame with every column of ``stock_daily`` for that code,
        sorted ascending by ``date`` (empty when nothing is cached).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # Bind the code as a parameter instead of interpolating it into the
        # SQL text, so quoting characters in ``code`` cannot alter the query.
        df = pd.read_sql_query(
            "SELECT * FROM stock_daily WHERE code = ?", conn, params=(code,)
        )
    finally:
        conn.close()
    return df.sort_values("date")
|||
|
|||
|
|||
def save_to_db(df, code):
    """Upsert the daily bars in ``df`` into ``stock_daily`` under ``code``.

    Unlike a plain append, rows whose ``(code, date)`` key already exists are
    replaced, so saving the same stock twice no longer raises an integrity
    error. The caller's frame is left untouched.

    Args:
        df: Frame with ``date``, ``open``, ``close``, ``high``, ``low`` and
            ``volume`` columns; extra columns are ignored.
        code: Stock code stored in the ``code`` column of every row.

    Note:
        The ``stock_daily`` table must already exist (see ``init_db``).
    """
    columns = ["code", "date", "open", "close", "high", "low", "volume"]

    # ``assign`` returns a new frame, so the caller's ``df`` is not mutated.
    # Dates are stored as text to match the TEXT column of the table.
    records = df.assign(code=code, date=df["date"].astype(str))[columns]
    rows = list(records.itertuples(index=False, name=None))
    if not rows:
        return

    placeholders = ", ".join("?" for _ in columns)
    sql = (
        f"INSERT OR REPLACE INTO stock_daily ({', '.join(columns)}) "
        f"VALUES ({placeholders})"
    )

    conn = sqlite3.connect(DB_PATH)
    try:
        # Commit all rows atomically; roll back if any insert fails.
        with conn:
            conn.executemany(sql, rows)
    finally:
        conn.close()
|||
|
|||
# ========================= |
|||
# 数据获取 |
|||
# ========================= |
|||
def fetch_data(code):
    """Download daily bars for ``code`` through akshare.

    ``ak.stock_zh_a_hist`` is tried first with the configured date range and
    forward-adjusted prices; if it raises, ``ak.stock_zh_a_daily`` is tried
    as a fallback.

    Args:
        code: Stock code passed to akshare as ``symbol``.

    Returns:
        A DataFrame with columns ``date, open, close, high, low, volume``
        (``date`` as strings), or ``None`` when no usable data is available.
    """
    required = ["date", "open", "close", "high", "low", "volume"]

    try:
        df = ak.stock_zh_a_hist(
            symbol=code,
            period="daily",
            start_date=CONFIG["START_DATE"],
            end_date=CONFIG["END_DATE"],
            adjust="qfq"
        )
    except Exception:
        # NOTE(review): the fallback is called without a date range or
        # adjustment, exactly as before -- confirm it accepts this symbol
        # format.
        try:
            df = ak.stock_zh_a_daily(symbol=code)
        except Exception:
            return None

    if df is None or df.empty:
        return None

    df = df.rename(columns={
        "日期": "date",
        "开盘": "open",
        "收盘": "close",
        "最高": "high",
        "最低": "low",
        "成交量": "volume"
    })

    # Check every column that is selected below, not only ``date``; a missing
    # one would otherwise surface as a KeyError.
    if any(column not in df.columns for column in required):
        return None

    df = df[required].copy()
    # Normalise dates to strings so freshly fetched rows sort and compare
    # consistently with rows read back from the TEXT column in SQLite.
    df["date"] = df["date"].astype(str)
    return df
|||
|
|||
|
|||
def get_data(code):
    """Return daily bars for ``code``, preferring the local SQLite cache.

    The cache is used only when it holds more than 30 rows for the code;
    otherwise the data is fetched and, if available, written to the cache.

    Args:
        code: Stock code.

    Returns:
        A DataFrame of daily bars, or ``None`` when fetching yields nothing.
    """
    min_cached_rows = 30

    try:
        cached = load_from_db(code)
    except Exception:
        # The cache is best effort: any read problem falls through to a fresh
        # fetch. ``Exception`` (not a bare except) keeps Ctrl-C and
        # SystemExit working.
        cached = None

    if cached is not None and len(cached) > min_cached_rows:
        return cached

    df = fetch_data(code)
    if df is not None:
        save_to_db(df, code)
    return df
|||
|
|||
# ========================= |
|||
# 特征工程 |
|||
# ========================= |
|||
def add_features(df):
    """Return a copy of ``df`` with the technical features added.

    Added columns:
        return_n:   close-price percentage change over ``RETURN_WINDOW`` rows.
        ma_short, ma_mid, ma_long: rolling means of ``close``.
        vol_ratio:  volume divided by its ``VOL_WINDOW`` rolling mean.
        volatility: rolling std of one-row returns over ``VOLATILITY_WINDOW``.
        trend:      True where ``ma_mid`` is above ``ma_long``.

    Rows containing any NaN (for example the warm-up rows of the longest
    rolling window) are dropped. The input frame is not modified.

    Args:
        df: Frame with at least ``close`` and ``volume`` columns. Rows are
            presumably in chronological order -- the function does not sort.

    Returns:
        A new DataFrame with the feature columns.
    """
    # Work on a copy so the caller's frame does not grow extra columns.
    out = df.copy()
    close = out['close']
    volume = out['volume']

    out['return_n'] = close.pct_change(CONFIG["RETURN_WINDOW"])

    out['ma_short'] = close.rolling(CONFIG["MA_SHORT"]).mean()
    out['ma_mid'] = close.rolling(CONFIG["MA_MID"]).mean()
    out['ma_long'] = close.rolling(CONFIG["MA_LONG"]).mean()

    out['vol_ratio'] = volume / volume.rolling(CONFIG["VOL_WINDOW"]).mean()
    out['volatility'] = close.pct_change().rolling(CONFIG["VOLATILITY_WINDOW"]).std()

    out['trend'] = out['ma_mid'] > out['ma_long']

    return out.dropna()
|||
|
|||
|
|||
def score(df):
    """Add a weighted composite ``score`` column to ``df`` and return it.

    The score is the sum of four weighted terms: momentum (``return_n``),
    the relative gap between the short and mid moving averages, the volume
    ratio, and volatility. ``df`` is modified in place and also returned.
    """
    ma_gap = (df['ma_short'] - df['ma_mid']) / df['ma_mid']

    weighted_terms = [
        CONFIG["W_MOMENTUM"] * df['return_n'],
        CONFIG["W_TREND"] * ma_gap,
        CONFIG["W_VOLUME"] * df['vol_ratio'],
        CONFIG["W_VOLATILITY"] * df['volatility'],
    ]

    # Accumulate left to right so the floating-point result matches a plain
    # chained sum of the four terms.
    composite = weighted_terms[0]
    for term in weighted_terms[1:]:
        composite = composite + term

    df['score'] = composite
    return df
|||
|
|||
# ========================= |
|||
# 指数模块 |
|||
# ========================= |
|||
def get_index_data(code):
    """Download daily closes for the index ``code`` through akshare.

    Args:
        code: Index symbol passed to ``ak.index_zh_a_hist``.

    Returns:
        A DataFrame with ``date`` and ``close`` columns, or ``None`` when the
        download fails, returns nothing, or lacks the expected columns.
    """
    try:
        df = ak.index_zh_a_hist(
            symbol=code,
            period="daily",
            start_date=CONFIG["START_DATE"],
            end_date=CONFIG["END_DATE"]
        )
    except Exception:
        # A failed download is reported as "no data"; ``Exception`` rather
        # than a bare except keeps Ctrl-C and SystemExit working.
        return None

    if df is None or df.empty:
        return None

    df = df.rename(columns={"日期": "date", "收盘": "close"})

    # Guard against a changed upstream schema instead of raising KeyError.
    if not {"date", "close"}.issubset(df.columns):
        return None

    return df[['date', 'close']]
|||
|
|||
|
|||
def index_trend(df):
    """Return True when the latest mid moving average is above the long one.

    Args:
        df: Frame with a ``close`` column; the last row is treated as the
            most recent observation.

    Returns:
        ``bool``. False when there are fewer rows than the longer window
        needs, including an empty frame, which previously raised IndexError.
    """
    closes = df['close']
    needed = max(CONFIG["MA_MID"], CONFIG["MA_LONG"])
    if len(closes) < needed:
        # With too little history the averages are NaN; treat as "no uptrend".
        return False

    ma_mid = closes.rolling(CONFIG["MA_MID"]).mean()
    ma_long = closes.rolling(CONFIG["MA_LONG"]).mean()
    return bool(ma_mid.iloc[-1] > ma_long.iloc[-1])
|||
|
|||
|
|||
def market_score(index_data):
    """Combine the trends of the tracked indices into one integer score.

    Scoring:
        hs300: +2 when trending up, -2 otherwise.
        zz500: +2 when trending up.
        cyb:   +1 when trending up.

    An index that is absent from ``index_data`` (for example because its
    download failed) contributes 0 instead of raising KeyError.

    Args:
        index_data: Mapping of index name to a frame accepted by
            ``index_trend``.

    Returns:
        The summed score, between -2 and 5.
    """
    def trending_up(name):
        # None means "no data for this index".
        frame = index_data.get(name)
        if frame is None:
            return None
        return index_trend(frame)

    total = 0

    hs300_up = trending_up("hs300")
    if hs300_up is not None:
        total += 2 if hs300_up else -2

    if trending_up("zz500"):
        total += 2

    if trending_up("cyb"):
        total += 1

    return total
|||
|
|||
|
|||
def market_strategy(score):
    """Map a market score to the fraction of cash to deploy.

    Returns 0.7 for scores of at least 4, 0.4 for at least 2, 0.2 for at
    least 0, and 0 for anything lower.
    """
    # (minimum score, position ratio), checked from the highest tier down.
    tiers = ((4, 0.7), (2, 0.4), (0, 0.2))
    for floor, ratio in tiers:
        if score >= floor:
            return ratio
    return 0
|||
|
|||
# ========================= |
|||
# 交易函数 |
|||
# ========================= |
|||
def trade_return(prices):
    """Return the net return of one trade over the price path ``prices``.

    The position is entered at the first price marked up by slippage. Each
    price in the path (including the first) is then treated as a possible
    exit, marked down by slippage. The trade ends at the first price whose
    return reaches the stop-loss or take-profit level, and is then booked at
    exactly that level; otherwise the return at the last price is used. Two
    fees are subtracted from the result.

    Returns 0 for an empty path.
    """
    if len(prices) == 0:
        return 0

    slippage = CONFIG["SLIPPAGE"]
    stop_loss = CONFIG["STOP_LOSS"]
    take_profit = CONFIG["TAKE_PROFIT"]
    round_trip_fee = CONFIG["FEE"] * 2

    entry = prices[0] * (1 + slippage)
    outcome = 0

    for quote in prices:
        outcome = (quote * (1 - slippage) - entry) / entry

        if outcome <= stop_loss:
            return stop_loss - round_trip_fee
        if outcome >= take_profit:
            return take_profit - round_trip_fee

    return outcome - round_trip_fee
|||
|
|||
# ========================= |
|||
# 回测 |
|||
# ========================= |
|||
def backtest_real(df_all, params, position_ratio, init_cash=100000, top_n=3):
    """Simulate the daily stock-picking strategy and return the equity curve.

    For every date except the last ``hold_days`` ones, the stocks that pass
    the return, volume and trend filters are ranked by ``score``. The top
    ``top_n`` are bought with ``position_ratio`` of current cash, split
    equally. Each trade is valued by ``trade_return`` over the next
    ``hold_days`` closes, and its profit is booked immediately.

    Args:
        df_all: Frame with ``date``, ``code``, ``close``, ``return_n``,
            ``vol_ratio``, ``trend`` and ``score`` columns.
        params: Dict with ``return_thresh``, ``vol_thresh`` and ``hold_days``.
        position_ratio: Fraction of cash deployed per day (0 disables trading).
        init_cash: Starting cash.
        top_n: Maximum number of stocks bought per day.

    Returns:
        List of cash values, one per simulated date.
    """
    cash = init_cash
    equity = []
    hold_days = params['hold_days']

    # Loop-invariant work: sort once so every per-code slice is chronological,
    # and split by code once instead of re-filtering the full frame per pick.
    ordered = df_all.sort_values('date')
    dates = sorted(ordered['date'].unique())
    history_by_code = {code: rows for code, rows in ordered.groupby('code')}

    for i in range(len(dates) - hold_days):
        today = dates[i]

        # Only rows dated today are candidates. Taking each stock's latest
        # row up to today would let a stock without a bar today be picked on
        # a stale signal and entered at a later price.
        latest = ordered[ordered['date'] == today]

        picks = latest[
            (latest['return_n'] > params['return_thresh']) &
            (latest['vol_ratio'] > params['vol_thresh']) &
            (latest['trend'])
        ].sort_values('score', ascending=False).head(top_n)

        if len(picks) == 0 or position_ratio == 0:
            equity.append(cash)
            continue

        per_stock = cash * position_ratio / len(picks)

        profit = 0
        for code in picks['code']:
            history = history_by_code[code]
            prices = history.loc[history['date'] >= today, 'close'].values[:hold_days]
            profit += per_stock * trade_return(prices)

        cash += profit
        equity.append(cash)

    return equity
|||
|
|||
# ========================= |
|||
# 调参 |
|||
# ========================= |
|||
def simple_grid(df, return_grid=(0.01, 0.02), vol_grid=(1.1, 1.2), position_ratio=0.4):
    """Grid-search the entry thresholds and return the best parameter dict.

    Every combination of ``return_grid`` x ``vol_grid`` is backtested on
    ``df``. Runs with fewer than 10 equity points are ignored. The remaining
    ones are ranked by total return minus maximum drawdown.

    Args:
        df: Training frame accepted by ``backtest_real``.
        return_grid: Candidate values for ``return_thresh``.
        vol_grid: Candidate values for ``vol_thresh``.
        position_ratio: Position ratio used for every trial backtest.

    Returns:
        The best ``{"return_thresh", "vol_thresh", "hold_days"}`` dict, or
        ``None`` when no combination produced a long enough equity curve.
    """
    best_objective = float("-inf")
    best_params = None

    for return_thresh in return_grid:
        for vol_thresh in vol_grid:

            params = {
                "return_thresh": return_thresh,
                "vol_thresh": vol_thresh,
                "hold_days": CONFIG["HOLD_DAYS"]
            }

            equity = backtest_real(df, params, position_ratio)

            # Too few points to judge the parameter set.
            if len(equity) < 10:
                continue

            total_return = (equity[-1] - equity[0]) / equity[0]
            drawdown = calc_max_drawdown(equity)

            objective = total_return - drawdown

            if objective > best_objective:
                best_objective = objective
                best_params = params

    return best_params
|||
|
|||
# ========================= |
|||
# 评估 |
|||
# ========================= |
|||
def calc_max_drawdown(equity):
    """Return the maximum peak-to-trough drawdown of an equity curve.

    The drawdown at each point is ``(running_peak - value) / running_peak``;
    the largest one is returned as a fraction (0.25 means 25%).

    Args:
        equity: Sequence of account values in chronological order.

    Returns:
        The maximum drawdown, or 0 for an empty curve. Points whose running
        peak is not positive are skipped, because a relative drawdown is
        undefined there.
    """
    # ``len`` rather than truthiness so array-like inputs work too.
    if len(equity) == 0:
        return 0

    peak = equity[0]
    max_dd = 0

    for value in equity:
        if value > peak:
            peak = value
        if peak > 0:
            max_dd = max(max_dd, (peak - value) / peak)

    return max_dd
|||
|
|||
# ========================= |
|||
# Walk Forward(带指数) |
|||
# ========================= |
|||
def walk_forward(df_all, index_data):
    """Run a walk-forward backtest with index-based position sizing.

    The distinct dates of ``df_all`` are cut into rolling windows of
    ``TRAIN_WINDOW`` training dates followed by ``TEST_WINDOW`` test dates,
    advancing by ``TEST_WINDOW``. For each window:

    1. ``simple_grid`` tunes the thresholds on the training slice.
    2. The market score is computed from index data up to the last training
       date only, so no information from the test period leaks in.
    3. The test slice is backtested, starting from the cash left by the
       previous window.

    Windows whose position ratio is 0 are skipped and add no points to the
    curve.

    Args:
        df_all: Feature frame for all stocks (see ``backtest_real``).
        index_data: Mapping of index name to a frame with ``date`` and
            ``close`` columns.

    Returns:
        The concatenated equity curve as a list (empty if nothing traded).
    """
    dates = sorted(df_all['date'].unique())

    window = CONFIG["TRAIN_WINDOW"]
    step = CONFIG["TEST_WINDOW"]

    # Used when the grid search cannot rank any parameter set.
    default_params = {
        "return_thresh": CONFIG["RETURN_THRESH"],
        "vol_thresh": CONFIG["VOL_THRESH"],
        "hold_days": CONFIG["HOLD_DAYS"],
    }

    # Parse the index dates once; stock and index dates may use different
    # representations, so they are compared as timestamps.
    index_dates = {
        name: pd.to_datetime(frame['date']) for name, frame in index_data.items()
    }

    final_equity = []

    # "+ 1" keeps the last window when the dates fit it exactly.
    for start in range(0, len(dates) - window - step + 1, step):

        train_dates = dates[start:start + window]
        test_dates = dates[start + window:start + window + step]

        train_df = df_all[df_all['date'].isin(train_dates)]
        test_df = df_all[df_all['date'].isin(test_dates)]

        best_params = simple_grid(train_df) or default_params

        # ===== Index trend decides the position ratio =====
        cutoff = pd.to_datetime(train_dates[-1])
        known_index = {
            name: frame[index_dates[name] <= cutoff]
            for name, frame in index_data.items()
        }
        if any(frame.empty for frame in known_index.values()):
            # No index history yet for this window: stay out of the market.
            continue

        position_ratio = market_strategy(market_score(known_index))

        if position_ratio == 0:
            continue

        if final_equity:
            # Carry the account forward instead of restarting each window
            # from the default capital and re-basing afterwards.
            equity = backtest_real(
                test_df, best_params, position_ratio, init_cash=final_equity[-1]
            )
        else:
            equity = backtest_real(test_df, best_params, position_ratio)

        final_equity.extend(equity)

    return final_equity
|||
|
|||
# ========================= |
|||
# 主程序 |
|||
# ========================= |
|||
def main():
    """Load data, run the walk-forward backtest and print a summary.

    Prints a message and returns early when no stock data could be loaded or
    when the backtest produced no equity points, instead of crashing on an
    empty concat or an empty list.
    """
    init_db()

    codes = ["000001", "000002", "600519", "600036", "601318"]

    all_data = []

    for code in codes:
        df = get_data(code)
        if df is None or df.empty:
            continue

        df['code'] = code
        df = add_features(df)
        df = score(df)
        all_data.append(df)

    if not all_data:
        # pd.concat raises ValueError on an empty list.
        print("无可用股票数据,回测终止")
        return

    df_all = pd.concat(all_data)

    # ===== Fetch every index once =====
    index_data = {}
    for name, code in CONFIG["INDEX_LIST"].items():
        df = get_index_data(code)
        if df is not None:
            index_data[name] = df

    equity = walk_forward(df_all, index_data)

    if not equity:
        # Every window was skipped or there were too few dates.
        print("回测未产生任何资金曲线")
        return

    print("\n===== 实盘级回测 =====")
    print("最终资金:", round(equity[-1], 2))

    ret = (equity[-1] - equity[0]) / equity[0]
    dd = calc_max_drawdown(equity)

    print("总收益:", round(ret, 2))
    print("最大回撤:", round(dd, 2))


if __name__ == "__main__":
    main()
|||
@ -0,0 +1,26 @@ |
|||
import sqlite3 |
|||
import pandas as pd |
|||
import os |
|||
|
|||
# SQLite file that holds the daily-bar table.
DB_PATH = "stock_data.db"


def init_db():
    """Create the ``stock_daily`` table in the SQLite file at ``DB_PATH``.

    The statement uses ``CREATE TABLE IF NOT EXISTS``, so calling this
    repeatedly is harmless. The connection is always closed, even when the
    statement fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # ``with conn`` commits on success and rolls back on error; it does
        # not close the connection, hence the surrounding try/finally.
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS stock_daily (
                    code TEXT,
                    date TEXT,
                    open REAL,
                    close REAL,
                    high REAL,
                    low REAL,
                    volume REAL,
                    PRIMARY KEY (code, date)
                )
            """)
    finally:
        conn.close()
|||
Binary file not shown.
Loading…
Reference in new issue