"""Flask app that runs a simple 10-trading-day backtest on an uploaded CSV.

The uploaded CSV must contain a ``code`` column (stock code) and a
``tradedate`` column (signal date).  For every row the app looks up the
closing-price change between the signal date and the 10th open trading day
after it (via Tushare) and renders the result as an HTML table together with
the share of rows that went up.
"""
import math
import os
from collections import Counter
from functools import lru_cache

import pandas as pd
import tushare as ts
from flask import Flask, render_template, request, redirect, url_for
from werkzeug.utils import secure_filename

# Initialise Tushare and the Flask application.
# NOTE(review): the API token is committed in source control.  It is kept as
# the fallback so existing deployments keep working, but it should be rotated
# and supplied through the TUSHARE_TOKEN environment variable instead.
pro = ts.pro_api(
    os.environ.get(
        'TUSHARE_TOKEN',
        'a9b83bc559587ad8b391c631b5e4eb93bd304f918c8700a81b13604f',
    )
)
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'uploads/'

# Create the uploads folder if it does not exist (exist_ok avoids the
# check-then-create race of a separate os.path.exists test).
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)


# User-provided helper functions.
def tsdate(date):
    """Return *date* with every '-' removed (e.g. '2023-01-05' -> '20230105').

    Non-string values (such as an int date parsed by pandas) are converted
    with ``str`` first instead of raising ``AttributeError``.
    """
    return str(date).replace('-', '')


def tscode(code):
    """Append an exchange suffix to a bare stock code.

    Codes starting with '6' get '.SH'; every other code gets '.SZ'.
    """
    return code + '.SH' if code.startswith('6') else code + '.SZ'


@lru_cache(maxsize=128)
def _open_trade_days(start_date):
    """Return the open SSE calendar days on or after *start_date*, ascending.

    *start_date* must already be in 'YYYYMMDD' form.  The result is cached per
    start date so a CSV with many rows sharing a date triggers one API call.
    """
    cal = pro.trade_cal(exchange='SSE', start_date=start_date)
    cal = cal[(cal['is_open'] == 1) & (cal['cal_date'] >= start_date)]
    # Sort explicitly rather than reversing: the result must be ascending no
    # matter which order the API happens to return the calendar in.
    return tuple(sorted(cal['cal_date'].tolist()))


def get_trade_date_after(start_date, n):
    """Return the *n*-th open trading day counted from *start_date*.

    Index 0 is the first open day on or after *start_date*.

    Raises:
        IndexError: if the calendar holds fewer than ``n + 1`` open days.
    """
    return _open_trade_days(tsdate(start_date))[n]


def extract_code_date(file_path):
    """Read the CSV at *file_path* and return ``[[code, tradedate], ...]``.

    Codes are converted to strings and left-padded with zeros to 6 characters.
    """
    df = pd.read_csv(file_path)
    code_date_df = df[['code', 'tradedate']].copy()
    code_date_df['code'] = code_date_df['code'].astype(str).str.zfill(6)
    return code_date_df.values.tolist()


def get_pct(code, start_date, end_date):
    """Return the percentage change in close between the first and last bar.

    Daily bars for *code* are requested for ``[start_date, end_date]``; the
    change is computed from the earliest to the latest returned close and
    rounded to two decimals.  When no bars are returned the result is NaN
    instead of an ``IndexError``.
    """
    daily_data = pro.daily(
        ts_code=tscode(code),
        start_date=tsdate(start_date),
        end_date=tsdate(end_date),
    )
    if daily_data is None or daily_data.empty:
        return float('nan')
    closes = daily_data.sort_values('trade_date')['close']
    first_close = closes.iloc[0]
    last_close = closes.iloc[-1]
    return round((last_close - first_close) / first_close * 100, 2)


# Route definitions.
@app.route('/')
def index():
    """Render the landing page."""
    return render_template('base_index.html')


@app.route('/simple_backtest', methods=['GET', 'POST'])
def simple_backtest():
    """Show the upload form (GET) or run the backtest on the upload (POST)."""
    if request.method != 'POST':
        # GET: render the page without any table data.
        return render_template('simple_backtest.html')

    file = request.files['file']
    # secure_filename may strip everything (e.g. a name made only of unsafe
    # characters); saving to the bare upload directory would then fail.
    filename = secure_filename(file.filename) if file else ''
    if not filename:
        return render_template('simple_backtest.html')

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    pct_list = []
    for code, start_date in extract_code_date(filepath):
        end_date = get_trade_date_after(start_date, 10)
        pct_list.append(get_pct(code, start_date, end_date))

    # Rows without price data (NaN) are shown in the table but excluded from
    # the up-probability; an empty CSV yields 0.0 rather than dividing by zero.
    valid_pcts = [pct for pct in pct_list if not math.isnan(pct)]
    counter = Counter(1 if pct > 0 else 0 for pct in valid_pcts)
    up_probability = counter[1] / len(valid_pcts) * 100 if valid_pcts else 0.0

    df = pd.read_csv(filepath)
    df['code'] = df['code'].astype(str).str.zfill(6)
    df['10_days_pct'] = pct_list

    return render_template(
        'simple_backtest.html',
        tables=[df.to_html(classes='table table-striped')],
        titles=df.columns.values,
        up_probability=up_probability,
    )


if __name__ == '__main__':
    # NOTE(review): debug=True enables the interactive Werkzeug debugger and
    # must not be used for a publicly reachable deployment.
    app.run(debug=True)