diff --git a/ak_share.py b/ak_share.py new file mode 100644 index 0000000..b4a31a5 --- /dev/null +++ b/ak_share.py @@ -0,0 +1,427 @@ +# ========================================= +# 工业级AI选股系统(最终整合版) +# 指数择时 + 动态仓位 + 回测 + 调参 +# ========================================= + +import sqlite3 +import akshare as ak +import pandas as pd + +DB_PATH = "stock_data.db" + +# ========================================= +# 全局参数配置 +# ========================================= +CONFIG = { + + # ===== 数据 ===== + "START_DATE": "20230101", + "END_DATE": "20240401", + + # ===== 均线 ===== + "MA_SHORT": 5, + "MA_MID": 20, + "MA_LONG": 60, + + # ===== 特征 ===== + "RETURN_WINDOW": 3, + "VOL_WINDOW": 5, + "VOLATILITY_WINDOW": 10, + + # ===== 权重 ===== + "W_MOMENTUM": 0.4, + "W_TREND": 0.3, + "W_VOLUME": 0.2, + "W_VOLATILITY": -0.1, + + # ===== 策略 ===== + "RETURN_THRESH": 0.02, + "VOL_THRESH": 1.2, + + "HOLD_DAYS": 3, + "STOP_LOSS": -0.05, + "TAKE_PROFIT": 0.1, + + # ===== 成本 ===== + "FEE": 0.001, + "SLIPPAGE": 0.001, + + # ===== Walk Forward ===== + "TRAIN_WINDOW": 60, + "TEST_WINDOW": 20, + + # ===== 指数 ===== + "INDEX_LIST": { + "hs300": "000300", + "zz500": "000905", + "cyb": "399006", + } +} + +# ========================= +# 数据库 +# ========================= +def init_db(): + conn = sqlite3.connect(DB_PATH) + conn.execute(""" + CREATE TABLE IF NOT EXISTS stock_daily ( + code TEXT, + date TEXT, + open REAL, + close REAL, + high REAL, + low REAL, + volume REAL, + PRIMARY KEY (code, date) + ) + """) + conn.close() + + +def load_from_db(code): + conn = sqlite3.connect(DB_PATH) + df = pd.read_sql_query( + f"SELECT * FROM stock_daily WHERE code='{code}'", conn + ) + conn.close() + return df.sort_values("date") + + +def save_to_db(df, code): + conn = sqlite3.connect(DB_PATH) + df['code'] = code + df.to_sql("stock_daily", conn, if_exists="append", index=False) + conn.close() + +# ========================= +# 数据获取 +# ========================= +def fetch_data(code): + try: + df = ak.stock_zh_a_hist( + symbol=code, + period="daily", + start_date=CONFIG["START_DATE"], + end_date=CONFIG["END_DATE"], + adjust="qfq" + ) + except: + df = ak.stock_zh_a_daily(symbol=code) + + if df is None or df.empty: + return None + + df = df.rename(columns={ + "日期": "date", + "开盘": "open", + "收盘": "close", + "最高": "high", + "最低": "low", + "成交量": "volume" + }) + + if "date" not in df.columns: + return None + + return df[["date", "open", "close", "high", "low", "volume"]] + + +def get_data(code): + try: + df = load_from_db(code) + if len(df) > 30: + return df + except: + pass + + df = fetch_data(code) + if df is not None: + save_to_db(df, code) + return df + +# ========================= +# 特征工程 +# ========================= +def add_features(df): + + df['return_n'] = df['close'].pct_change(CONFIG["RETURN_WINDOW"]) + + df['ma_short'] = df['close'].rolling(CONFIG["MA_SHORT"]).mean() + df['ma_mid'] = df['close'].rolling(CONFIG["MA_MID"]).mean() + df['ma_long'] = df['close'].rolling(CONFIG["MA_LONG"]).mean() + + df['vol_ratio'] = df['volume'] / df['volume'].rolling(CONFIG["VOL_WINDOW"]).mean() + df['volatility'] = df['close'].pct_change().rolling(CONFIG["VOLATILITY_WINDOW"]).std() + + df['trend'] = df['ma_mid'] > df['ma_long'] + + return df.dropna() + + +def score(df): + df['score'] = ( + CONFIG["W_MOMENTUM"] * df['return_n'] + + CONFIG["W_TREND"] * ((df['ma_short'] - df['ma_mid']) / df['ma_mid']) + + CONFIG["W_VOLUME"] * df['vol_ratio'] + + CONFIG["W_VOLATILITY"] * df['volatility'] + ) + return df + +# ========================= +# 指数模块 +# ========================= +def get_index_data(code): + try: + df = ak.index_zh_a_hist( + symbol=code, + period="daily", + start_date=CONFIG["START_DATE"], + end_date=CONFIG["END_DATE"] + ) + except: + return None + + if df is None or df.empty: + return None + + df = df.rename(columns={"日期": "date", "收盘": "close"}) + return df[['date', 'close']] + + +def index_trend(df): + ma20 = df['close'].rolling(CONFIG["MA_MID"]).mean() + ma60 = df['close'].rolling(CONFIG["MA_LONG"]).mean() + return ma20.iloc[-1] > ma60.iloc[-1] + + +def market_score(index_data): + score = 0 + + if index_trend(index_data["hs300"]): + score += 2 + else: + score -= 2 + + if index_trend(index_data["zz500"]): + score += 2 + + if index_trend(index_data["cyb"]): + score += 1 + + return score + + +def market_strategy(score): + if score >= 4: + return 0.7 + elif score >= 2: + return 0.4 + elif score >= 0: + return 0.2 + else: + return 0 + +# ========================= +# 交易函数 +# ========================= +def trade_return(prices): + if len(prices) == 0: + return 0 + + entry = prices[0] * (1 + CONFIG["SLIPPAGE"]) + ret = 0 + + for p in prices: + exit_price = p * (1 - CONFIG["SLIPPAGE"]) + r = (exit_price - entry) / entry + + if r <= CONFIG["STOP_LOSS"]: + ret = CONFIG["STOP_LOSS"] + break + elif r >= CONFIG["TAKE_PROFIT"]: + ret = CONFIG["TAKE_PROFIT"] + break + else: + ret = r + + return ret - CONFIG["FEE"] * 2 + +# ========================= +# 回测 +# ========================= +def backtest_real(df_all, params, position_ratio, init_cash=100000): + + cash = init_cash + equity = [] + + dates = sorted(df_all['date'].unique()) + + for i in range(len(dates) - params['hold_days']): + + today = dates[i] + + df_today = df_all[df_all['date'] <= today] + latest = df_today.groupby('code').tail(1) + + picks = latest[ + (latest['return_n'] > params['return_thresh']) & + (latest['vol_ratio'] > params['vol_thresh']) & + (latest['trend']) + ].sort_values('score', ascending=False).head(3) + + if len(picks) == 0 or position_ratio == 0: + equity.append(cash) + continue + + position = cash * position_ratio + per_stock = position / len(picks) + + profit = 0 + + for _, row in picks.iterrows(): + code = row['code'] + + prices = df_all[ + (df_all['code'] == code) & + (df_all['date'] >= today) + ].sort_values('date')['close'].values[:params['hold_days']] + + r = trade_return(prices) + profit += per_stock * r + + cash += profit + equity.append(cash) + + return equity + +# ========================= +# 调参 +# ========================= +def simple_grid(df): + best_score = -999 + best_params = None + + for r in [0.01, 0.02]: + for v in [1.1, 1.2]: + + params = { + "return_thresh": r, + "vol_thresh": v, + "hold_days": CONFIG["HOLD_DAYS"] + } + + equity = backtest_real(df, params, 0.4) + + if len(equity) < 10: + continue + + ret = (equity[-1] - equity[0]) / equity[0] + dd = calc_max_drawdown(equity) + + score = ret - dd + + if score > best_score: + best_score = score + best_params = params + + return best_params + +# ========================= +# 评估 +# ========================= +def calc_max_drawdown(equity): + peak = equity[0] + max_dd = 0 + + for x in equity: + if x > peak: + peak = x + dd = (peak - x) / peak + max_dd = max(max_dd, dd) + + return max_dd + +# ========================= +# Walk Forward(带指数) +# ========================= +def walk_forward(df_all, index_data): + + dates = sorted(df_all['date'].unique()) + + window = CONFIG["TRAIN_WINDOW"] + step = CONFIG["TEST_WINDOW"] + + final_equity = [] + + for start in range(0, len(dates) - window - step, step): + + train_dates = dates[start:start+window] + test_dates = dates[start+window:start+window+step] + + train_df = df_all[df_all['date'].isin(train_dates)] + test_df = df_all[df_all['date'].isin(test_dates)] + + best_params = simple_grid(train_df) + + # ===== 指数决定仓位 ===== + score = market_score(index_data) + position_ratio = market_strategy(score) + + if position_ratio == 0: + continue + + equity = backtest_real(test_df, best_params, position_ratio) + + if len(final_equity) == 0: + final_equity = equity + else: + base = final_equity[-1] + equity = [base + (x - equity[0]) for x in equity] + final_equity.extend(equity) + + return final_equity + +# ========================= +# 主程序 +# ========================= +def main(): + + init_db() + + codes = ["000001", "000002", "600519", "600036", "601318"] + + all_data = [] + + for code in codes: + df = get_data(code) + if df is None: + continue + + df['code'] = code + df = add_features(df) + df = score(df) + all_data.append(df) + + df_all = pd.concat(all_data) + + # ===== 指数一次性获取 ===== + index_data = {} + for name, code in CONFIG["INDEX_LIST"].items(): + df = get_index_data(code) + if df is not None: + index_data[name] = df + + equity = walk_forward(df_all, index_data) + + print("\n===== 实盘级回测 =====") + print("最终资金:", round(equity[-1], 2)) + + ret = (equity[-1] - equity[0]) / equity[0] + dd = calc_max_drawdown(equity) + + print("总收益:", round(ret, 2)) + print("最大回撤:", round(dd, 2)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/data_engine.py b/data_engine.py new file mode 100644 index 0000000..ecdb03a --- /dev/null +++ b/data_engine.py @@ -0,0 +1,26 @@ +import sqlite3 +import pandas as pd +import os + +DB_PATH = "stock_data.db" + + +def init_db(): + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS stock_daily ( + code TEXT, + date TEXT, + open REAL, + close REAL, + high REAL, + low REAL, + volume REAL, + PRIMARY KEY (code, date) + ) + """) + + conn.commit() + conn.close() \ No newline at end of file diff --git a/stock_data.db b/stock_data.db new file mode 100644 index 0000000..107f41e Binary files /dev/null and b/stock_data.db differ