3 changed files with 453 additions and 0 deletions
@ -0,0 +1,427 @@ |
|||||
|
# ========================================= |
||||
|
# 工业级AI选股系统(最终整合版) |
||||
|
# 指数择时 + 动态仓位 + 回测 + 调参 |
||||
|
# ========================================= |
||||
|
|
||||
|
import sqlite3 |
||||
|
import akshare as ak |
||||
|
import pandas as pd |
||||
|
|
||||
|
DB_PATH = "stock_data.db"

# =========================================
# Global parameter configuration
# =========================================
CONFIG = dict(
    # ----- data range passed to the data provider -----
    START_DATE="20230101",
    END_DATE="20240401",

    # ----- moving-average window lengths -----
    MA_SHORT=5,
    MA_MID=20,
    MA_LONG=60,

    # ----- feature windows -----
    RETURN_WINDOW=3,
    VOL_WINDOW=5,
    VOLATILITY_WINDOW=10,

    # ----- score weights (volatility is weighted negatively) -----
    W_MOMENTUM=0.4,
    W_TREND=0.3,
    W_VOLUME=0.2,
    W_VOLATILITY=-0.1,

    # ----- strategy thresholds -----
    RETURN_THRESH=0.02,
    VOL_THRESH=1.2,

    HOLD_DAYS=3,
    STOP_LOSS=-0.05,
    TAKE_PROFIT=0.1,

    # ----- trading costs -----
    FEE=0.001,
    SLIPPAGE=0.001,

    # ----- walk-forward window sizes -----
    TRAIN_WINDOW=60,
    TEST_WINDOW=20,

    # ----- market indices: short name -> index code -----
    INDEX_LIST=dict(
        hs300="000300",
        zz500="000905",
        cyb="399006",
    ),
)
||||
|
|
||||
|
# ========================= |
||||
|
# 数据库 |
||||
|
# ========================= |
||||
|
def init_db():
    """Create the ``stock_daily`` table in the SQLite file at ``DB_PATH``.

    The statement uses ``CREATE TABLE IF NOT EXISTS``, so calling this more
    than once is harmless. The connection is always closed, even when the
    statement fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS stock_daily (
                code TEXT,
                date TEXT,
                open REAL,
                close REAL,
                high REAL,
                low REAL,
                volume REAL,
                PRIMARY KEY (code, date)
            )
        """)
        # Explicit commit so the schema change never depends on the
        # connection's implicit transaction handling.
        conn.commit()
    finally:
        conn.close()
||||
|
|
||||
|
|
||||
|
def load_from_db(code):
    """Load every stored ``stock_daily`` row for ``code``, sorted by date.

    Args:
        code: Value matched against the ``code`` column.

    Returns:
        A DataFrame with all table columns, ordered by the ``date`` column
        (empty when nothing is stored for ``code``).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # Bind the code as a parameter instead of formatting it into the SQL
        # text, so quotes inside ``code`` cannot alter the statement.
        df = pd.read_sql_query(
            "SELECT * FROM stock_daily WHERE code = ?", conn, params=(code,)
        )
    finally:
        conn.close()
    return df.sort_values("date")
||||
|
|
||||
|
|
||||
|
def save_to_db(df, code):
    """Append the rows of ``df`` to the ``stock_daily`` table under ``code``.

    Args:
        df: DataFrame whose columns are written as table columns.
        code: Value stored in the ``code`` column for every row.

    The caller's DataFrame is left untouched: the ``code`` column is added to
    a copy. The connection is closed even if the insert fails.
    """
    frame = df.copy()
    frame['code'] = code

    conn = sqlite3.connect(DB_PATH)
    try:
        # NOTE(review): rows are appended as-is; re-saving a (code, date) pair
        # that is already stored will conflict with the table's primary key
        # when the table was created by init_db.
        frame.to_sql("stock_daily", conn, if_exists="append", index=False)
    finally:
        conn.close()
||||
|
|
||||
|
# ========================= |
||||
|
# 数据获取 |
||||
|
# ========================= |
||||
|
def fetch_data(code):
    """Download daily bars for ``code`` through akshare.

    ``ak.stock_zh_a_hist`` is tried first with the configured date range and
    ``adjust="qfq"``; if it raises, ``ak.stock_zh_a_daily`` is tried instead.

    Args:
        code: Stock symbol passed to the akshare calls.

    Returns:
        A DataFrame with the columns ``date, open, close, high, low, volume``,
        or ``None`` when both downloads fail, the result is empty, or any of
        those columns is missing after renaming.
    """
    required = ["date", "open", "close", "high", "low", "volume"]

    try:
        df = ak.stock_zh_a_hist(
            symbol=code,
            period="daily",
            start_date=CONFIG["START_DATE"],
            end_date=CONFIG["END_DATE"],
            adjust="qfq",
        )
    except Exception:
        # Primary endpoint failed; fall back to the alternative one, and
        # report "no data" instead of crashing if that fails as well.
        # NOTE(review): stock_zh_a_daily may expect an exchange-prefixed
        # symbol (e.g. "sz000001") and is not limited to the configured
        # date range - confirm against the akshare documentation.
        try:
            df = ak.stock_zh_a_daily(symbol=code)
        except Exception:
            return None

    if df is None or df.empty:
        return None

    # Map the provider's Chinese column headers to the names used in this file.
    df = df.rename(columns={
        "日期": "date",
        "开盘": "open",
        "收盘": "close",
        "最高": "high",
        "最低": "low",
        "成交量": "volume",
    })

    # Validate every column we are about to select, not just "date".
    if any(column not in df.columns for column in required):
        return None

    return df[required]
||||
|
|
||||
|
|
||||
|
def get_data(code):
    """Return daily bars for ``code``, preferring the local database.

    Stored rows are used when more than 30 are available; otherwise the data
    is downloaded with ``fetch_data`` and written back via ``save_to_db``.

    Args:
        code: Stock symbol.

    Returns:
        A DataFrame of daily bars, or ``None`` when nothing could be fetched.
    """
    try:
        cached = load_from_db(code)
    except Exception:
        # Best effort: an unreadable cache (e.g. missing table) just means we
        # download. ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        cached = None

    if cached is not None and len(cached) > 30:
        return cached

    df = fetch_data(code)
    if df is not None:
        save_to_db(df, code)
    return df
||||
|
|
||||
|
# ========================= |
||||
|
# 特征工程 |
||||
|
# ========================= |
||||
|
def add_features(df):
    """Compute the indicator columns used for scoring and stock selection.

    Added columns: ``return_n``, ``ma_short``, ``ma_mid``, ``ma_long``,
    ``vol_ratio``, ``volatility`` and the boolean ``trend``. Window lengths
    come from ``CONFIG``.

    Args:
        df: DataFrame with at least ``close`` and ``volume`` columns. Rows
            are assumed to be in chronological order (not checked here).

    Returns:
        A new DataFrame with the feature columns added and every row that
        contains a NaN removed. The input frame is not modified.
    """
    # Work on a copy so the caller's frame (possibly a slice of a larger one)
    # is never written to.
    out = df.copy()
    close = out['close']
    volume = out['volume']

    out['return_n'] = close.pct_change(CONFIG["RETURN_WINDOW"])

    out['ma_short'] = close.rolling(CONFIG["MA_SHORT"]).mean()
    out['ma_mid'] = close.rolling(CONFIG["MA_MID"]).mean()
    out['ma_long'] = close.rolling(CONFIG["MA_LONG"]).mean()

    # Volume relative to its own rolling average.
    out['vol_ratio'] = volume / volume.rolling(CONFIG["VOL_WINDOW"]).mean()
    # Rolling standard deviation of one-row returns.
    out['volatility'] = close.pct_change().rolling(CONFIG["VOLATILITY_WINDOW"]).std()

    out['trend'] = out['ma_mid'] > out['ma_long']

    return out.dropna()
||||
|
|
||||
|
|
||||
|
def score(df):
    """Attach a weighted ``score`` column to ``df`` and return the same frame.

    The score adds four weighted terms, using the ``W_*`` weights in
    ``CONFIG``: ``return_n``, the relative gap between ``ma_short`` and
    ``ma_mid``, ``vol_ratio`` and ``volatility``.
    """
    momentum_term = CONFIG["W_MOMENTUM"] * df['return_n']

    ma_gap = (df['ma_short'] - df['ma_mid']) / df['ma_mid']
    trend_term = CONFIG["W_TREND"] * ma_gap

    volume_term = CONFIG["W_VOLUME"] * df['vol_ratio']
    volatility_term = CONFIG["W_VOLATILITY"] * df['volatility']

    # Summed left to right in the same order as the weights are listed above.
    df['score'] = momentum_term + trend_term + volume_term + volatility_term
    return df
||||
|
|
||||
|
# ========================= |
||||
|
# 指数模块 |
||||
|
# ========================= |
||||
|
def get_index_data(code):
    """Download daily closes for the index ``code`` through akshare.

    Args:
        code: Index symbol passed to ``ak.index_zh_a_hist``.

    Returns:
        A DataFrame with ``date`` and ``close`` columns, or ``None`` when the
        download raises, returns nothing, or lacks either column.
    """
    try:
        df = ak.index_zh_a_hist(
            symbol=code,
            period="daily",
            start_date=CONFIG["START_DATE"],
            end_date=CONFIG["END_DATE"],
        )
    except Exception:
        # A failed download is reported as "no data"; ``Exception`` (not a
        # bare ``except``) keeps KeyboardInterrupt / SystemExit propagating.
        return None

    if df is None or df.empty:
        return None

    df = df.rename(columns={"日期": "date", "收盘": "close"})

    # Guard the column selection so an unexpected layout yields None rather
    # than a KeyError.
    if "date" not in df.columns or "close" not in df.columns:
        return None

    return df[['date', 'close']]
||||
|
|
||||
|
|
||||
|
def index_trend(df):
    """Tell whether the latest mid moving average is above the long one.

    Args:
        df: DataFrame with a ``close`` column (may be ``None`` or empty).

    Returns:
        ``True`` when the last ``MA_MID``-row mean of ``close`` exceeds the
        last ``MA_LONG``-row mean. ``False`` otherwise, including when there
        is no data or too few rows for either average (the mean is then NaN
        and the comparison is false).
    """
    if df is None or len(df) == 0:
        # No rows: nothing to index with iloc[-1].
        return False

    closes = df['close']
    mid_avg = closes.rolling(CONFIG["MA_MID"]).mean().iloc[-1]
    long_avg = closes.rolling(CONFIG["MA_LONG"]).mean().iloc[-1]

    # Convert the numpy boolean to a plain Python bool.
    return bool(mid_avg > long_avg)
||||
|
|
||||
|
|
||||
|
def market_score(index_data):
    """Combine the trend of three indices into one integer market score.

    Scoring: ``hs300`` contributes +2 when trending up and -2 otherwise;
    ``zz500`` contributes +2 when up; ``cyb`` contributes +1 when up.
    An index that is absent from ``index_data`` contributes nothing instead
    of raising ``KeyError``.

    Args:
        index_data: Mapping from index name to its price DataFrame.

    Returns:
        The summed score as an int.
    """
    def trend_of(name):
        # None means "no data for this index".
        frame = index_data.get(name)
        if frame is None:
            return None
        return index_trend(frame)

    total = 0

    hs300_up = trend_of("hs300")
    if hs300_up is not None:
        total += 2 if hs300_up else -2

    if trend_of("zz500"):
        total += 2

    if trend_of("cyb"):
        total += 1

    return total
||||
|
|
||||
|
|
||||
|
def market_strategy(score):
    """Map a market score to the fraction of cash to invest.

    Returns 0.7 for scores of at least 4, 0.4 for at least 2, 0.2 for at
    least 0, and 0 for anything lower.
    """
    # (minimum score, position ratio), checked from the highest tier down.
    tiers = ((4, 0.7), (2, 0.4), (0, 0.2))
    for floor, ratio in tiers:
        if score >= floor:
            return ratio
    return 0
||||
|
|
||||
|
# ========================= |
||||
|
# 交易函数 |
||||
|
# ========================= |
||||
|
def trade_return(prices):
    """Return the net return of one trade over the price sequence ``prices``.

    The position is entered at the first price plus slippage. Each price in
    turn (the first one included) is treated as an exit price minus slippage;
    the walk stops at the first price whose return reaches ``STOP_LOSS`` or
    ``TAKE_PROFIT``, and the result is then exactly that limit. Otherwise the
    return at the last price is used. Twice ``FEE`` is subtracted from the
    result. An empty sequence yields 0.
    """
    if len(prices) == 0:
        return 0

    entry = prices[0] * (1 + CONFIG["SLIPPAGE"])
    exit_factor = 1 - CONFIG["SLIPPAGE"]
    stop_level = CONFIG["STOP_LOSS"]
    target_level = CONFIG["TAKE_PROFIT"]

    outcome = 0
    for quote in prices:
        change = (quote * exit_factor - entry) / entry

        if change <= stop_level:
            outcome = stop_level
            break
        if change >= target_level:
            outcome = target_level
            break
        outcome = change

    return outcome - CONFIG["FEE"] * 2
||||
|
|
||||
|
# ========================= |
||||
|
# 回测 |
||||
|
# ========================= |
||||
|
def backtest_real(df_all, params, position_ratio, init_cash=100000):
    """Simulate the pick-and-hold strategy day by day and return the equity curve.

    For each date except the last ``hold_days`` ones, the most recent row of
    every code up to that date is filtered by the thresholds in ``params`` and
    the ``trend`` flag; the three highest-scoring survivors are bought with
    ``cash * position_ratio`` split equally, and each trade's result comes
    from ``trade_return`` over the next ``hold_days`` closes.

    Args:
        df_all: DataFrame with ``date``, ``code``, ``close``, ``return_n``,
            ``vol_ratio``, ``trend`` and ``score`` columns.
        params: Dict with ``return_thresh``, ``vol_thresh`` and ``hold_days``.
        position_ratio: Fraction of cash deployed on each trading day.
        init_cash: Starting cash.

    Returns:
        List of cash values, one per simulated date.
    """
    hold_days = params['hold_days']
    cash = init_cash
    equity = []

    # Sort once (stably) so that "last row per code" below really is the most
    # recent bar, whatever row order the caller passed in.
    ordered = df_all.sort_values('date', kind='mergesort')
    dates = sorted(ordered['date'].unique())

    # Pre-split by code so each trade only scans its own stock's rows instead
    # of re-filtering and re-sorting the full frame.
    rows_by_code = {code: frame for code, frame in ordered.groupby('code')}

    # NOTE(review): a new trade is opened on every date and its whole
    # multi-day result is booked on the entry date, so holding periods overlap
    # and share the same cash - confirm this is the intended accounting.
    for i in range(len(dates) - hold_days):
        today = dates[i]

        latest = ordered[ordered['date'] <= today].groupby('code').tail(1)

        picks = latest[
            (latest['return_n'] > params['return_thresh'])
            & (latest['vol_ratio'] > params['vol_thresh'])
            & (latest['trend'])
        ].sort_values('score', ascending=False).head(3)

        if len(picks) == 0 or position_ratio == 0:
            equity.append(cash)
            continue

        per_stock = cash * position_ratio / len(picks)

        profit = 0
        for code in picks['code']:
            frame = rows_by_code[code]
            prices = frame.loc[frame['date'] >= today, 'close'].values[:hold_days]
            profit += per_stock * trade_return(prices)

        cash += profit
        equity.append(cash)

    return equity
||||
|
|
||||
|
# ========================= |
||||
|
# 调参 |
||||
|
# ========================= |
||||
|
def simple_grid(df):
    """Grid-search the selection thresholds on ``df`` and return the best set.

    Every combination of return threshold (0.01, 0.02) and volume threshold
    (1.1, 1.2) is backtested with a 0.4 position ratio. A combination is
    scored as total return minus maximum drawdown; runs with fewer than 10
    equity points are ignored.

    Returns:
        The parameter dict with the highest score, or ``None`` when no
        combination produced a long enough equity curve.
    """
    candidates = [
        {
            "return_thresh": return_thresh,
            "vol_thresh": vol_thresh,
            "hold_days": CONFIG["HOLD_DAYS"],
        }
        for return_thresh in [0.01, 0.02]
        for vol_thresh in [1.1, 1.2]
    ]

    best_score = -999
    best_params = None

    for candidate in candidates:
        curve = backtest_real(df, candidate, 0.4)
        if len(curve) < 10:
            continue

        total_return = (curve[-1] - curve[0]) / curve[0]
        drawdown = calc_max_drawdown(curve)
        objective = total_return - drawdown

        if objective > best_score:
            best_score = objective
            best_params = candidate

    return best_params
||||
|
|
||||
|
# ========================= |
||||
|
# 评估 |
||||
|
# ========================= |
||||
|
def calc_max_drawdown(equity):
    """Return the largest peak-to-trough decline of ``equity`` as a fraction.

    Args:
        equity: Sequence of account values in chronological order.

    Returns:
        The maximum of ``(peak - value) / peak`` over the sequence, where
        ``peak`` is the running maximum. 0 for an empty or never-declining
        sequence.
    """
    if len(equity) == 0:
        # Nothing to measure; avoids indexing equity[0] on an empty curve.
        return 0

    peak = equity[0]
    max_dd = 0

    for value in equity:
        if value > peak:
            peak = value
        # A non-positive peak has no meaningful relative drawdown and would
        # divide by zero when it is exactly 0.
        if peak > 0:
            max_dd = max(max_dd, (peak - value) / peak)

    return max_dd
||||
|
|
||||
|
# ========================= |
||||
|
# Walk Forward(带指数) |
||||
|
# ========================= |
||||
|
def walk_forward(df_all, index_data):
    """Run a rolling train/test backtest and return the stitched equity curve.

    The sorted unique dates are cut into windows of ``TRAIN_WINDOW`` training
    dates followed by ``TEST_WINDOW`` test dates, advancing by ``TEST_WINDOW``.
    Parameters are tuned on each training slice with ``simple_grid`` and then
    applied to the test slice with ``backtest_real``.

    Args:
        df_all: Feature DataFrame for all codes (needs a ``date`` column plus
            whatever ``simple_grid`` / ``backtest_real`` read).
        index_data: Mapping of index name to index price DataFrame, passed to
            ``market_score``.

    Returns:
        List of cash values across all test windows; empty when there is no
        complete window or the market score maps to a zero position.
    """
    dates = sorted(df_all['date'].unique())

    window = CONFIG["TRAIN_WINDOW"]
    step = CONFIG["TEST_WINDOW"]

    # "+ 1" keeps the final window when the dates fit it exactly.
    starts = range(0, len(dates) - window - step + 1, step)
    final_equity = []
    if len(starts) == 0:
        return final_equity

    # The market score depends only on index_data, so compute it once.
    # NOTE(review): it uses the full index history, i.e. the same (latest)
    # reading for every window - slice index_data per window if a
    # point-in-time signal is wanted.
    position_ratio = market_strategy(market_score(index_data))
    if position_ratio == 0:
        return final_equity

    for start in starts:
        train_dates = dates[start:start + window]
        test_dates = dates[start + window:start + window + step]

        train_df = df_all[df_all['date'].isin(train_dates)]
        test_df = df_all[df_all['date'].isin(test_dates)]

        best_params = simple_grid(train_df)
        if best_params is None:
            # No parameter set qualified on this training slice.
            continue

        if final_equity:
            # Carry the running cash into the next segment so the curve is
            # continuous and no segment's first-day result is discarded.
            segment = backtest_real(
                test_df, best_params, position_ratio,
                init_cash=final_equity[-1],
            )
        else:
            segment = backtest_real(test_df, best_params, position_ratio)

        final_equity.extend(segment)

    return final_equity
||||
|
|
||||
|
# ========================= |
||||
|
# 主程序 |
||||
|
# ========================= |
||||
|
def main():
    """Load data, build features, run the walk-forward backtest, print results.

    Steps: initialise the database, load and score each stock in a fixed code
    list, download the indices listed in ``CONFIG["INDEX_LIST"]``, run
    ``walk_forward`` and print final cash, total return and maximum drawdown.
    Exits early with a message when no stock data or no backtest result is
    available.
    """
    init_db()

    codes = ["000001", "000002", "600519", "600036", "601318"]

    all_data = []
    for code in codes:
        df = get_data(code)
        if df is None or df.empty:
            continue

        df['code'] = code
        df = add_features(df)
        df = score(df)
        all_data.append(df)

    if not all_data:
        # pd.concat raises on an empty list; stop with a clear message instead.
        print("没有可用的股票数据,无法回测")
        return

    df_all = pd.concat(all_data)

    # ===== fetch every index once up front =====
    index_data = {}
    for name, code in CONFIG["INDEX_LIST"].items():
        df = get_index_data(code)
        if df is not None:
            index_data[name] = df

    equity = walk_forward(df_all, index_data)

    if not equity:
        # Nothing was simulated (no full window or zero position ratio), so
        # there is no equity[-1] to report.
        print("回测未产生任何结果")
        return

    print("\n===== 实盘级回测 =====")
    print("最终资金:", round(equity[-1], 2))

    ret = (equity[-1] - equity[0]) / equity[0]
    dd = calc_max_drawdown(equity)

    print("总收益:", round(ret, 2))
    print("最大回撤:", round(dd, 2))


if __name__ == "__main__":
    main()
||||
@ -0,0 +1,26 @@ |
|||||
|
import sqlite3 |
||||
|
import pandas as pd |
||||
|
import os |
||||
|
|
||||
|
DB_PATH = "stock_data.db" |
||||
|
|
||||
|
|
||||
|
def init_db():
    """Create the ``stock_daily`` table in the SQLite file at ``DB_PATH``.

    Uses ``CREATE TABLE IF NOT EXISTS``, so repeated calls are safe. The
    connection is closed even when the statement fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # The connection's context manager commits on success and rolls back
        # on error; it does not close the connection, hence the finally.
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS stock_daily (
                    code TEXT,
                    date TEXT,
                    open REAL,
                    close REAL,
                    high REAL,
                    low REAL,
                    volume REAL,
                    PRIMARY KEY (code, date)
                )
            """)
    finally:
        conn.close()
||||
Binary file not shown.
Loading…
Reference in new issue