main.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from flask import Flask, render_template, request, redirect, url_for
  2. import pandas as pd
  3. from werkzeug.utils import secure_filename
  4. import os
  5. from collections import Counter
  6. import tushare as ts
  # Initialise the Tushare Pro client and the Flask application.
  # NOTE(review): this API token was committed to source control. It should be
  # rotated and supplied via the TUSHARE_TOKEN environment variable; the literal
  # is kept only as a fallback so existing setups keep working until then.
  pro = ts.pro_api(os.environ.get('TUSHARE_TOKEN', 'a9b83bc559587ad8b391c631b5e4eb93bd304f918c8700a81b13604f'))
  app = Flask(__name__)
  app.config['UPLOAD_FOLDER'] = 'uploads/'
  # Create the uploads folder if it does not exist yet. exist_ok=True replaces the
  # separate os.path.exists() check, which could race with a concurrent creator.
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
  # Helper functions (originally supplied by the user).
  def tsdate(date):
      """Convert a date to Tushare's compact ``YYYYMMDD`` string form.

      Dashes are removed, so ``'2023-01-05'`` becomes ``'20230105'``. The value
      is passed through ``str()`` first so that integer dates (what pandas
      produces for a ``20230105`` CSV column) and ``datetime.date`` objects are
      accepted as well as strings; surrounding whitespace is stripped.

      Args:
          date: A date as ``'YYYY-MM-DD'``, ``'YYYYMMDD'``, int, or ``date``.

      Returns:
          The date as a string with all dashes removed.
      """
      return str(date).strip().replace('-', '')
  def tscode(code):
      """Append a Tushare exchange suffix to a bare stock code.

      A code whose first character is ``'6'`` is tagged ``.SH``; every other
      code is tagged ``.SZ``.
      NOTE(review): no other exchange suffix is produced here, so codes from
      other exchanges would be mislabelled. Confirm the expected input universe.
      """
      exchange = 'SH' if code[0] == '6' else 'SZ'
      return code + '.' + exchange
  def get_trade_date_after(start_date, n):
      """Return the trading day at offset ``n`` among SSE open days >= ``start_date``.

      The open days on or after ``start_date`` are listed in ascending order.
      Index 0 is the first of them, which is ``start_date`` itself when that is
      a trading day.

      Args:
          start_date: Start date as ``'YYYY-MM-DD'`` or ``'YYYYMMDD'``; ints
              such as ``20230105`` are accepted too.
          n: Index into the ascending list of open days.

      Returns:
          The selected trading day as a ``'YYYYMMDD'`` string.

      Raises:
          IndexError: if the calendar returned by Tushare does not contain
              enough open days after ``start_date``.
      """
      start = str(start_date).strip().replace('-', '')
      cal = pro.trade_cal(exchange='SSE', start_date=start)
      # is_open may arrive as an int or as a '0'/'1' string; normalise before comparing.
      open_mask = cal['is_open'].astype(int) == 1
      in_range = cal['cal_date'] >= start
      # Sort explicitly instead of reversing: the original [::-1] silently relied
      # on the API returning rows newest-first.
      trading_days = sorted(cal.loc[open_mask & in_range, 'cal_date'])
      try:
          return trading_days[n]
      except IndexError:
          raise IndexError(
              f'only {len(trading_days)} trading days on or after {start}; '
              f'cannot take index {n}'
          ) from None
  def extract_code_date(file_path):
      """Read ``[code, tradedate]`` pairs from an uploaded CSV.

      Args:
          file_path: Path (or file-like object) of a CSV file that has ``code``
              and ``tradedate`` columns.

      Returns:
          A list of ``[code, tradedate]`` lists, one per row, in file order.
          ``code`` is a string left-padded with zeros to at least 6 characters.
          ``tradedate`` is the text exactly as it appears in the file.
      """
      # Read both columns as text. Otherwise pandas turns '000001' into 1 and
      # '20230105' into an int, and int dates break the str-based date helpers.
      df = pd.read_csv(file_path, dtype={'code': str, 'tradedate': str})
      code_date_df = df[['code', 'tradedate']].copy()
      code_date_df['code'] = code_date_df['code'].astype(str).str.zfill(6)
      return code_date_df.values.tolist()
  def get_pct(code, start_date, end_date):
      """Return the close-to-close percentage change of ``code`` over a window.

      Daily bars from ``start_date`` to ``end_date`` are fetched from Tushare.
      The result compares the latest close in the window with the earliest
      one, as a percentage rounded to 2 decimals.

      Args:
          code: Bare 6-digit stock code (the exchange suffix is added here).
          start_date: Window start, ``'YYYY-MM-DD'`` or ``'YYYYMMDD'``.
          end_date: Window end, same formats.

      Raises:
          ValueError: if Tushare returns no bars for the window (for example a
              suspended or unknown stock). Previously this surfaced as an opaque
              ``iloc`` IndexError.
      """
      daily_data = pro.daily(ts_code=tscode(code), start_date=tsdate(start_date), end_date=tsdate(end_date))
      if daily_data is None or daily_data.empty:
          raise ValueError(f'no daily data for {code} between {start_date} and {end_date}')
      # Sort by trade date so iloc[0] is the earliest bar and iloc[-1] the latest.
      closes = daily_data.sort_values('trade_date')['close']
      first_close, last_close = closes.iloc[0], closes.iloc[-1]
      return round((last_close - first_close) / first_close * 100, 2)
  # Route definitions.
  @app.route('/')
  def index():
      """Serve the landing page; this view only renders the base template."""
      landing_template = 'base_index.html'
      return render_template(landing_template)
  @app.route('/simple_backtest', methods=['GET', 'POST'])
  def simple_backtest():
      """Backtest an uploaded CSV of signals over a 10-trading-day horizon.

      GET, or a POST without a usable file, renders ``simple_backtest.html``
      with no table data. Otherwise the upload (multipart field ``file``) is
      saved to the upload folder and read as a CSV with ``code`` and
      ``tradedate`` columns. For each row, the return from ``tradedate`` to the
      trading day 10 positions later is computed and added as
      ``10_days_pct``. The percentage of rows with a positive return is passed
      to the template as ``up_probability``.
      """
      if request.method != 'POST':
          # GET: render the page without any table data.
          return render_template('simple_backtest.html')

      # .get() so a request without a 'file' field shows the empty page
      # instead of failing with a 400 error.
      upload = request.files.get('file')
      filename = secure_filename(upload.filename) if upload else ''
      if not filename:
          # No file chosen, or secure_filename sanitised the name to '' (it can).
          # Saving to the bare upload folder would then fail.
          return render_template('simple_backtest.html')
      filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
      upload.save(filepath)

      code_date_list = extract_code_date(filepath)
      pct_list = []
      for code, start_date in code_date_list:
          end_date = get_trade_date_after(start_date, 10)
          pct_list.append(get_pct(code, start_date, end_date))

      up_count = sum(1 for pct in pct_list if pct > 0)
      # A CSV with a header but no rows used to raise ZeroDivisionError here.
      up_probability = up_count / len(pct_list) * 100 if pct_list else 0.0

      # Re-read the full file for display, keeping codes as text so leading
      # zeros survive; rows line up with pct_list because both follow file order.
      df = pd.read_csv(filepath, dtype={'code': str})
      df['code'] = df['code'].astype(str).str.zfill(6)
      df['10_days_pct'] = pct_list
      return render_template('simple_backtest.html',
                             tables=[df.to_html(classes='table table-striped')],
                             titles=df.columns.values,
                             up_probability=up_probability)
  if __name__ == '__main__':
      # When the module is executed directly, start the app via app.run()
      # with debug mode switched on.
      debug_mode = True
      app.run(debug=debug_mode)