Browse Source

vue 和 main 合并

master
zhanglei 1 month ago
parent
commit
10010fc56f
  1. 427
      ak_share.py
  2. 26
      data_engine.py
  3. BIN
      stock_data.db

427
ak_share.py

@ -0,0 +1,427 @@
# =========================================
# 工业级AI选股系统(最终整合版)
# 指数择时 + 动态仓位 + 回测 + 调参
# =========================================
import sqlite3
import akshare as ak
import pandas as pd
DB_PATH = "stock_data.db"
# =========================================
# 全局参数配置
# =========================================
CONFIG = {
# ===== 数据 =====
"START_DATE": "20230101",
"END_DATE": "20240401",
# ===== 均线 =====
"MA_SHORT": 5,
"MA_MID": 20,
"MA_LONG": 60,
# ===== 特征 =====
"RETURN_WINDOW": 3,
"VOL_WINDOW": 5,
"VOLATILITY_WINDOW": 10,
# ===== 权重 =====
"W_MOMENTUM": 0.4,
"W_TREND": 0.3,
"W_VOLUME": 0.2,
"W_VOLATILITY": -0.1,
# ===== 策略 =====
"RETURN_THRESH": 0.02,
"VOL_THRESH": 1.2,
"HOLD_DAYS": 3,
"STOP_LOSS": -0.05,
"TAKE_PROFIT": 0.1,
# ===== 成本 =====
"FEE": 0.001,
"SLIPPAGE": 0.001,
# ===== Walk Forward =====
"TRAIN_WINDOW": 60,
"TEST_WINDOW": 20,
# ===== 指数 =====
"INDEX_LIST": {
"hs300": "000300",
"zz500": "000905",
"cyb": "399006",
}
}
# =========================
# 数据库
# =========================
def init_db():
conn = sqlite3.connect(DB_PATH)
conn.execute("""
CREATE TABLE IF NOT EXISTS stock_daily (
code TEXT,
date TEXT,
open REAL,
close REAL,
high REAL,
low REAL,
volume REAL,
PRIMARY KEY (code, date)
)
""")
conn.close()
def load_from_db(code):
conn = sqlite3.connect(DB_PATH)
df = pd.read_sql_query(
f"SELECT * FROM stock_daily WHERE code='{code}'", conn
)
conn.close()
return df.sort_values("date")
def save_to_db(df, code):
conn = sqlite3.connect(DB_PATH)
df['code'] = code
df.to_sql("stock_daily", conn, if_exists="append", index=False)
conn.close()
# =========================
# 数据获取
# =========================
def fetch_data(code):
try:
df = ak.stock_zh_a_hist(
symbol=code,
period="daily",
start_date=CONFIG["START_DATE"],
end_date=CONFIG["END_DATE"],
adjust="qfq"
)
except:
df = ak.stock_zh_a_daily(symbol=code)
if df is None or df.empty:
return None
df = df.rename(columns={
"日期": "date",
"开盘": "open",
"收盘": "close",
"最高": "high",
"最低": "low",
"成交量": "volume"
})
if "date" not in df.columns:
return None
return df[["date", "open", "close", "high", "low", "volume"]]
def get_data(code):
try:
df = load_from_db(code)
if len(df) > 30:
return df
except:
pass
df = fetch_data(code)
if df is not None:
save_to_db(df, code)
return df
# =========================
# 特征工程
# =========================
def add_features(df):
df['return_n'] = df['close'].pct_change(CONFIG["RETURN_WINDOW"])
df['ma_short'] = df['close'].rolling(CONFIG["MA_SHORT"]).mean()
df['ma_mid'] = df['close'].rolling(CONFIG["MA_MID"]).mean()
df['ma_long'] = df['close'].rolling(CONFIG["MA_LONG"]).mean()
df['vol_ratio'] = df['volume'] / df['volume'].rolling(CONFIG["VOL_WINDOW"]).mean()
df['volatility'] = df['close'].pct_change().rolling(CONFIG["VOLATILITY_WINDOW"]).std()
df['trend'] = df['ma_mid'] > df['ma_long']
return df.dropna()
def score(df):
df['score'] = (
CONFIG["W_MOMENTUM"] * df['return_n'] +
CONFIG["W_TREND"] * ((df['ma_short'] - df['ma_mid']) / df['ma_mid']) +
CONFIG["W_VOLUME"] * df['vol_ratio'] +
CONFIG["W_VOLATILITY"] * df['volatility']
)
return df
# =========================
# 指数模块
# =========================
def get_index_data(code):
try:
df = ak.index_zh_a_hist(
symbol=code,
period="daily",
start_date=CONFIG["START_DATE"],
end_date=CONFIG["END_DATE"]
)
except:
return None
if df is None or df.empty:
return None
df = df.rename(columns={"日期": "date", "收盘": "close"})
return df[['date', 'close']]
def index_trend(df):
ma20 = df['close'].rolling(CONFIG["MA_MID"]).mean()
ma60 = df['close'].rolling(CONFIG["MA_LONG"]).mean()
return ma20.iloc[-1] > ma60.iloc[-1]
def market_score(index_data):
score = 0
if index_trend(index_data["hs300"]):
score += 2
else:
score -= 2
if index_trend(index_data["zz500"]):
score += 2
if index_trend(index_data["cyb"]):
score += 1
return score
def market_strategy(score):
if score >= 4:
return 0.7
elif score >= 2:
return 0.4
elif score >= 0:
return 0.2
else:
return 0
# =========================
# 交易函数
# =========================
def trade_return(prices):
if len(prices) == 0:
return 0
entry = prices[0] * (1 + CONFIG["SLIPPAGE"])
ret = 0
for p in prices:
exit_price = p * (1 - CONFIG["SLIPPAGE"])
r = (exit_price - entry) / entry
if r <= CONFIG["STOP_LOSS"]:
ret = CONFIG["STOP_LOSS"]
break
elif r >= CONFIG["TAKE_PROFIT"]:
ret = CONFIG["TAKE_PROFIT"]
break
else:
ret = r
return ret - CONFIG["FEE"] * 2
# =========================
# 回测
# =========================
def backtest_real(df_all, params, position_ratio, init_cash=100000):
cash = init_cash
equity = []
dates = sorted(df_all['date'].unique())
for i in range(len(dates) - params['hold_days']):
today = dates[i]
df_today = df_all[df_all['date'] <= today]
latest = df_today.groupby('code').tail(1)
picks = latest[
(latest['return_n'] > params['return_thresh']) &
(latest['vol_ratio'] > params['vol_thresh']) &
(latest['trend'])
].sort_values('score', ascending=False).head(3)
if len(picks) == 0 or position_ratio == 0:
equity.append(cash)
continue
position = cash * position_ratio
per_stock = position / len(picks)
profit = 0
for _, row in picks.iterrows():
code = row['code']
prices = df_all[
(df_all['code'] == code) &
(df_all['date'] >= today)
].sort_values('date')['close'].values[:params['hold_days']]
r = trade_return(prices)
profit += per_stock * r
cash += profit
equity.append(cash)
return equity
# =========================
# 调参
# =========================
def simple_grid(df):
best_score = -999
best_params = None
for r in [0.01, 0.02]:
for v in [1.1, 1.2]:
params = {
"return_thresh": r,
"vol_thresh": v,
"hold_days": CONFIG["HOLD_DAYS"]
}
equity = backtest_real(df, params, 0.4)
if len(equity) < 10:
continue
ret = (equity[-1] - equity[0]) / equity[0]
dd = calc_max_drawdown(equity)
score = ret - dd
if score > best_score:
best_score = score
best_params = params
return best_params
# =========================
# 评估
# =========================
def calc_max_drawdown(equity):
peak = equity[0]
max_dd = 0
for x in equity:
if x > peak:
peak = x
dd = (peak - x) / peak
max_dd = max(max_dd, dd)
return max_dd
# =========================
# Walk Forward(带指数)
# =========================
def walk_forward(df_all, index_data):
dates = sorted(df_all['date'].unique())
window = CONFIG["TRAIN_WINDOW"]
step = CONFIG["TEST_WINDOW"]
final_equity = []
for start in range(0, len(dates) - window - step, step):
train_dates = dates[start:start+window]
test_dates = dates[start+window:start+window+step]
train_df = df_all[df_all['date'].isin(train_dates)]
test_df = df_all[df_all['date'].isin(test_dates)]
best_params = simple_grid(train_df)
# ===== 指数决定仓位 =====
score = market_score(index_data)
position_ratio = market_strategy(score)
if position_ratio == 0:
continue
equity = backtest_real(test_df, best_params, position_ratio)
if len(final_equity) == 0:
final_equity = equity
else:
base = final_equity[-1]
equity = [base + (x - equity[0]) for x in equity]
final_equity.extend(equity)
return final_equity
# =========================
# 主程序
# =========================
def main():
init_db()
codes = ["000001", "000002", "600519", "600036", "601318"]
all_data = []
for code in codes:
df = get_data(code)
if df is None:
continue
df['code'] = code
df = add_features(df)
df = score(df)
all_data.append(df)
df_all = pd.concat(all_data)
# ===== 指数一次性获取 =====
index_data = {}
for name, code in CONFIG["INDEX_LIST"].items():
df = get_index_data(code)
if df is not None:
index_data[name] = df
equity = walk_forward(df_all, index_data)
print("\n===== 实盘级回测 =====")
print("最终资金:", round(equity[-1], 2))
ret = (equity[-1] - equity[0]) / equity[0]
dd = calc_max_drawdown(equity)
print("总收益:", round(ret, 2))
print("最大回撤:", round(dd, 2))
if __name__ == "__main__":
main()

26
data_engine.py

@ -0,0 +1,26 @@
import sqlite3
import pandas as pd
import os
DB_PATH = "stock_data.db"
def init_db():
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS stock_daily (
code TEXT,
date TEXT,
open REAL,
close REAL,
high REAL,
low REAL,
volume REAL,
PRIMARY KEY (code, date)
)
""")
conn.commit()
conn.close()

BIN
stock_data.db

Binary file not shown.
Loading…
Cancel
Save