import tushare as ts
import db
import history
import my_thread
import time


# 初始化ts
pro = ts.pro_api(token='233302841e61938a0e1b77e74dd6fe86703e0a5ddece3544569c9017')

t1 = time.time()


# 排序参数，取第2位排序
def take_second(elem):
    return elem[1]


# 排序参数，取第4位排序
def take_four(elem):
    return elem[3]


# 排序参数，取第4位排序
def take_five(elem):
    return elem[4]


# 判断今天是否交易日
def check_trade_day():
    toDay = time.strftime('%Y%m%d', time.localtime())
    h = int(time.strftime('%H', time.localtime()))
    m = int(time.strftime('%M', time.localtime()))
    cal = pro.query('trade_cal', start_date=toDay, end_date=toDay)
    for i in cal.index:
        if cal.loc[i][2] == 1:
            # 过滤9点半之前
            if h <= 9 and m < 30:
                return False
            # 过滤11点半之后
            if h == 11 and m > 30:
                return False
            return True
        else:
            return False


# 获取正常运行的行业分类列表
def get_running_list():

    list = pro.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date')

    mysql.save_cate_table(list)
    return list


# 获取实时行情
def get_now_price():
    # 实时行情API
    today = ts.get_today_all()
    # 保存到实时行情表
    mysql.save_now_table(today)

    # 取出涨幅接近涨停的票
    for i in today.index:
        d = today.loc[i]
        if d['changepercent'] > 9.6:
            # print(d['code'] + " " + d['name'])
            v = {'code': d['code'], 'name': d['name'], 'percent': d['changepercent']}
            maxData.append(v)

    # 给接近涨停的票加上[行业分类]参数
    for i in range(len(maxData)):
        for j in category.index:
            if maxData[i]['code'] == category.loc[j]['symbol']:
                maxData[i]['cate'] = category.loc[j]['industry']

    print(maxData)


# 取出接近涨停票中的cate分类
def get_hot_cate():
    # 取出接近涨停票中的cate分类
    # 并计算出现每个分类的次数
    # 目的是取出[热门行业]
    cates = {}
    for i in range(len(maxData)):
        key = maxData[i].get('cate')
        if key in cates:
            num = cates.get(key) + 1
        else:
            num = 1
        cates[key] = num

    # 分类排序并取出排名
    sorts = []
    cate_list = list(cates.items())
    cate_list.sort(key=take_second, reverse=True)

    # 取行业排名前6
    for i in range(0, 6):
        if cate_list[i][0] is not None:
            sorts.append(cate_list[i][0])

    print('\n')
    print(sorts)

    query_sort_history(sorts)


# 查询同一行业里最近180天的高位数据
def query_sort_history(sorts):
    # 清空历史表
    mysql.clear_history()

    tt = []
    # 处理每个行业的数据
    for i in range(len(sorts)):
        thread1 = my_thread.MyThread('Thread-'+str(i), sorts[i])
        thread1.start()

        tt.append(thread1)

        # 取出同行业全部票
        # stocks = mysql.getListForCate(sorts[i])
        # stockPercent = []
        # for i in range(len(stocks)):
        #     code = stocks[i][0]
        #     # 查询票当前涨幅
        #     p = mysql.getPercent(code)
        #     for j in p:
        #         s = (float(j[0]),)
        #         a = stocks[i] + s
        #         stockPercent.append(a)
        # # 根据涨幅排序
        # stockPercent.sort(key=take_four, reverse=True)
        # # print(stockPercent)
        #
        # for z in range(len(stockPercent)):
        #     if 3 <= float(stockPercent[z][3]):
        #         # print(stockPercent[z])
        #         # 查询单个的180天历史高位数据，并保存到数据库
        #         toDay = time.strftime('%Y-%m-%d', time.localtime())
        #         history.query_stock_history(mysql, toDay, stockPercent[z][0],
        #                                     stockPercent[z][1],
        #                                     stockPercent[z][2])

    # 等待所有线程执行完成
    for i in tt:
        i.join()

    compile_stock(sorts)


# 生成比对结果
def compile_stock(sorts):
    current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
    # 准备写文件
    f = open('./code.txt', 'a+')
    f.write("\t\n\n")
    f.write(str(sorts))
    f.write("\t\n")
    f.write(current_time)

    # 处理每个行业的数据
    for x in range(len(sorts)):
        stocks = mysql.getListForCate(sorts[x])
        stockPercent = []
        for i in range(len(stocks)):
            code = stocks[i][0]
            p = mysql.getPercent(code)
            for j in p:
                s = (float(j[0]),)
                a = stocks[i] + s
                # print(a)
                stockPercent.append(a)
        # 根据涨幅排序
        stockPercent.sort(key=take_four, reverse=True)
        # print(stockPercent)

        # 今日龙头
        maxPercent = []
        for i in range(len(stockPercent)):
            # 取出今天行业的领涨龙头
            if 9.8 <= float(stockPercent[i][3]):
                maxPercent.append(stockPercent[i])

        # print(maxPercent)

        allStr = ''
        for i in maxPercent:
            # 查询龙头历史数据
            days = mysql.selectDay(i[0])
            for j in days:
                result = mysql.selectRelationStocks(j[4], i[2])
                # print(str(result))
                allStr += str(result)
        # print(allStr)

        # 处理字符
        allStr = allStr.replace("(", "")
        allStr = allStr.replace(")", "")
        allStr = allStr.replace("'", "")
        allStr = allStr.replace(",,", ",")
        allStr = allStr.replace(" ", "")
        # print(allStr)

        # 移除重复字符
        sortList = []
        for i in allStr.split(','):
            if i != '':
                d = (i, allStr.count(i))
                if sortList.count(d) <= 0:
                    sortList.append(d)
        sortList.sort(key=take_second, reverse=True)
        # print(sortList)

        result = []
        for i in sortList:
            # 查询历史中高位出现的次数
            num = mysql.selectCount(i[0])[0]
            if num >= 3:
                p = 0
                price = 0
                n = ''

                # 从保存的实时信息中取出
                df = mysql.get_now(i[0])
                for j in df:
                    p = j[3]
                    price = j[4]
                    n = j[2]

                # 过滤百分比太高和太低的数据
                if 1 <= float(p) <= 7:
                    # 查询15天内的上升趋势
                    toDay = time.strftime('%Y-%m-%d', time.localtime())
                    t = history.query_month_history(i[0], toDay)
                    if t[0]:
                        value = (i[0] + ' ' + n, price, str(p) + "%", t[1], num)
                        result.append(value)
                        # 保存结果到数据库
                        mysql.save_result(current_time, sorts[x], str(value))

        result.sort(key=take_five, reverse=True)

        print(sorts[x])
        print(result)

        f.write("\t\n" + sorts[x] + " => " + str(result))

    f.close()



if check_trade_day() is False:
    print('~~ 非交易时间 ~~')
    exit()

# 初始化数据库操作类
mysql = db.MY_SQL()

# 行业分类API()
category = get_running_list()

# 接近涨停的票数据
maxData = []

# 获取实时行情
get_now_price()

# 取出接近涨停票中的cate分类
get_hot_cate()


mysql.closeDB()

print(time.time() - t1)

