Ver código fonte

最简单的回测页,上传csv,获取带涨幅的表格

caijynb 2 anos atrás
commit
f96a84ba3c
3 arquivos alterados com 113 adições e 0 exclusões
  1. 2 0
      .gitignore
  2. 84 0
      main.py
  3. 27 0
      templates/index.html

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+.idea/
+uploads/*

+ 84 - 0
main.py

@@ -0,0 +1,84 @@
+from flask import Flask, render_template, request, redirect, url_for
+import pandas as pd
+from werkzeug.utils import secure_filename
+import os
+from collections import Counter
+import tushare as ts
+
+# 初始化Tushare和Flask应用
+pro = ts.pro_api('a9b83bc559587ad8b391c631b5e4eb93bd304f918c8700a81b13604f')
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = 'uploads/'
+
+# 创建 uploads 文件夹,如果不存在
+if not os.path.exists(app.config['UPLOAD_FOLDER']):
+    os.makedirs(app.config['UPLOAD_FOLDER'])
+
+# 用户提供的函数
+def tsdate(date):
+    return date.replace('-', '')
+
+def tscode(code):
+    return code + '.SH' if code[0] == '6' else code + '.SZ'
+
+def get_trade_date_after(start_date, n):
+    start_date = start_date.replace('-', '')
+    cal = pro.trade_cal(exchange='SSE', start_date=start_date)
+    cal = cal[cal['is_open'] == 1]
+    trading_days = cal[cal['cal_date'] >= start_date]
+    trading_days = trading_days['cal_date'].tolist()[::-1]
+    return trading_days[n]
+
+def extract_code_date(file_path):
+    df = pd.read_csv(file_path)
+    code_date_df = df[['code', 'tradedate']].copy()
+    code_date_df['code'] = code_date_df['code'].astype(str).str.zfill(6)
+    result_list = code_date_df.values.tolist()
+    return result_list
+
+def get_pct(code, start_date, end_date):
+    daily_data = pro.daily(ts_code=tscode(code), start_date=tsdate(start_date), end_date=tsdate(end_date))
+    daily_data.set_index('trade_date', inplace=True)
+    daily_data.sort_index(inplace=True)
+    pct = round((daily_data.iloc[-1]['close'] - daily_data.iloc[0]['close']) / daily_data.iloc[0]['close'] * 100, 2)
+    return pct
+
+
+# 路由定义
+@app.route('/', methods=['GET', 'POST'])
+def index():
+    if request.method == 'POST':
+        file = request.files['file']
+        if file:
+            filename = secure_filename(file.filename)
+            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+            file.save(filepath)
+
+
+            # 读取CSV文件
+            code_date_list = extract_code_date(filepath)
+            pct_list = []
+            up_down_flags = []
+
+            for code, start_date in code_date_list:
+                end_date = get_trade_date_after(start_date, 10)
+                pct = get_pct(code, start_date, end_date)
+                pct_list.append(pct)
+                up_down_flags.append(1 if pct > 0 else 0)
+
+            counter = Counter(up_down_flags)
+            up_probability = counter[1] / len(up_down_flags) * 100
+
+            df = pd.read_csv(filepath)
+            df['code'] = df['code'].astype(str).str.zfill(6)
+            df['10_days_pct'] = pct_list
+
+            return render_template('index.html', 
+                                   tables=[df.to_html(classes='table table-striped')], 
+                                   titles=df.columns.values,
+                                   up_probability=up_probability)
+
+    return render_template('index.html')
+
+if __name__ == '__main__':
+    app.run(debug=True)

+ 27 - 0
templates/index.html

@@ -0,0 +1,27 @@
+
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>股票数据分析</title>
+    <link href="https://cdn.bootcdn.net/ajax/libs/bootstrap/4.5.2/css/bootstrap.min.css" rel="stylesheet">
+</head>
+<body>
+
+<div class="container mt-5">
+    <h1>上传CSV文件</h1>
+    <form action="/" method="post" enctype="multipart/form-data">
+        <input type="file" name="file">
+        <input type="submit" value="上传">
+    </form>
+
+    {% if tables %}
+        <h2>分析结果</h2>
+        {% for table in tables %}
+            {{ table|safe }}
+        {% endfor %}
+    {% endif %}
+</div>
+
+</body>
+</html>