import argparse
import time
import logging
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import math
from data.dataset import  *
from model.model_vgg16 import *

# Log file path (Windows-style). Raw strings keep the backslashes literal and
# avoid invalid-escape warnings (e.g. "\m", "\c"); the resulting values are
# identical to the previous escaped literals.
log_file = r"E:\mypaper\code\jishu_MLP\model\log\test_log\log.log"
# NOTE(review): main() creates and uses its own local SummaryWriter, so this
# module-level writer is never written to or closed — consider removing it.
writer = SummaryWriter(r'E:\mypaper\code\jishu_MLP\model\log\test_log\loss.log')

# Root logger: INFO and above go to log_file with a timestamped format.
# The file handler is opened explicitly as UTF-8 because many log messages
# contain Chinese text, which can fail to encode under some locale default
# encodings (e.g. cp1252).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file, encoding='utf-8')],
)



# Number of DataLoader worker processes (main() also passes it to
# torch.set_num_threads).
workers = 4
# Path the best-on-validation state_dict is saved to and later reloaded from.
save_file = "weight_best.pth"
# Run-name fragment for the TensorBoard log directory ("log/" + save_sumary).
save_sumary = "vgg16_2023_9_"

# Also echo INFO-and-above root-logger records to the console (stderr).
console = logging.StreamHandler()
console.setLevel(logging.INFO)
_root_logger = logging.getLogger()
_root_logger.addHandler(console)

def log_info(message):
    """Record *message* on the root logger at INFO level."""
    logging.log(logging.INFO, message)
def _resolve_data_paths(run_device):
    """Return (train_root, train_ann, val_root, val_ann, test_root, test_ann).

    ``run_device == 'cs'`` selects the local experiment layout on drive E:;
    any other value selects the GWHD-2021 layout on drive D:. Directory roots
    keep a trailing backslash (presumably the dataset class concatenates file
    names onto them — confirm).
    """
    if run_device == 'cs':
        data_root = 'E:\\mypaper\\My_experiment\\MLP_jishu\\jishu\\'
        # NOTE(review): validation reuses the training images and annotations,
        # so "best" checkpoint selection is done on training data — confirm.
        train_root = data_root + 'fused\\'
        train_ann = data_root + 'label\\fus281.csv'
        val_root = data_root + 'fused\\'
        val_ann = data_root + 'label\\fus281.csv'
        test_data_root = 'E:\\mypaper\\My_experiment\\MLP_jishu\\jishu\\train_fuse_new_images\\'
        test_root = test_data_root + 'train_set\\'
        test_ann = test_data_root + 'label\\fus_new_118.csv'
    else:
        data_root = 'D:\\datasets\\gwhd_2021\\'
        train_root = data_root + 'train\\'
        train_ann = data_root + 'train.csv'
        val_root = data_root + 'val\\'
        val_ann = data_root + 'val.csv'
        test_root = data_root + 'test\\'
        test_ann = data_root + 'test.csv'
    return train_root, train_ann, val_root, val_ann, test_root, test_ann


def main():
    """Train the counting model, validate it periodically, then run the test set.

    Command-line options: ``--device``, ``--batch_size``, ``--run_device``
    (``'cs'`` or anything else; selects the dataset layout), ``--epoch``,
    ``--img_size`` and ``--lr``. Validation runs every 5 epochs starting at
    epoch 10; whenever the validation MAE drops below the best value so far
    (initially 10) the model's state_dict is written to ``save_file``.
    Training loss and validation MAE go to TensorBoard under
    ``"log/" + save_sumary``.
    """
    print("4layer_2021_vgg16_1")
    # ArgumentParser's first positional parameter is ``prog``; the text is a
    # description, so pass it by keyword.
    parser = argparse.ArgumentParser(description='Set parameters for training ', add_help=False)
    parser.add_argument('--device', default="cuda", type=str)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--run_device', default='cs', type=str)
    parser.add_argument("--epoch", default=1000, type=int)
    parser.add_argument("--img_size", default=256, type=int)
    parser.add_argument("--lr", default=0.0001, type=float)
    args = parser.parse_args()

    # Training parameters.
    device = args.device
    learning_rate = args.lr
    # Network hyper-parameters.
    batch_size = args.batch_size
    epoch = args.epoch
    img_size = args.img_size

    train_root, train_ann, val_root, val_ann, test_root, test_ann = _resolve_data_paths(args.run_device)

    # Datasets.
    train_dataset = dataset.Countgwhd(img_path=train_root, ann_path=train_ann, resize_shape=img_size)
    val_dataset = dataset.Countgwhd(img_path=val_root, ann_path=val_ann, resize_shape=img_size)
    test_dataset = dataset.Countgwhd(img_path=test_root, ann_path=test_ann, resize_shape=img_size)

    # Model.
    model = Multi_Granularity(device=device)
    model.to(device)

    # Loss: L1 with reduction="sum" (summed, not averaged, over the batch).
    loss_fn = nn.L1Loss(reduction="sum").to(device)

    # Optimizer; train() steps the scheduler once per epoch, so the LR is
    # multiplied by 0.1 once 800 epochs have completed.
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[800], gamma=0.1)

    torch.set_num_threads(workers)
    writer = SummaryWriter("log/" + save_sumary)
    # Sentinel "best" validation MAE: test() only reloads save_file when
    # best < 10, i.e. when some validation run beat this value and saved weights.
    best = 10
    val_epoch = 0
    try:
        # Training loop.
        for i in range(epoch):
            epoch_start = time.time()
            epoch_msg = "----------epoch: {}, lr: {}----------".format(i + 1, optimizer.param_groups[0]['lr'])
            print(epoch_msg)
            log_info(epoch_msg)
            loss_one_epoch = train(train_dataset, model, loss_fn, optimizer, lr_scheduler, batch_size, device)
            writer.add_scalar("train_loss", loss_one_epoch, i)
            epoch_time_msg = "这轮所用时间为：{}min \n\n".format((time.time() - epoch_start) / 60)
            print(epoch_time_msg)
            log_info(epoch_time_msg)

            # Validate every 5 epochs, starting from epoch 10.
            if (i + 1) % 5 == 0 and i >= 9:
                val_epoch += 1
                val_start = time.time()
                print("----------开始验证----------")
                log_info("----------开始验证----------")
                prec = val(val_dataset, model, device)
                if prec < best:
                    best = prec
                    torch.save(model.state_dict(), save_file)
                val_minutes = (time.time() - val_start) / 60
                print("测试所用时间为：{}min".format(val_minutes))
                print("当前最好的mae为：{}".format(best))
                log_info("测试所用时间为：{}min".format(val_minutes))
                log_info("当前最好的mae为：{}".format(best))
                writer.add_scalar("val_mae", prec, val_epoch)

        print("\n----------开始测试----------")
        test(test_dataset, model, device, best)
    finally:
        # Flush and close the TensorBoard event file even if training fails.
        writer.close()

def train(train_dataset, model, loss_fn, optimizer, lr_scheduler, batch_size, device):
    """Run one training epoch and return the mean per-batch loss.

    Builds a shuffled DataLoader (``drop_last=True``, ``num_workers=workers``),
    takes one optimizer step per batch, prints/logs the batch loss every 20
    batches, and steps ``lr_scheduler`` once at the end of the epoch.

    Args:
        train_dataset: dataset whose batches unpack as ``((imgs, targets), filenames)``.
        model: torch module to train (switched to train mode).
        loss_fn: loss callable applied as ``loss_fn(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler stepped once per call.
        batch_size: DataLoader batch size.
        device: device the inputs and targets are moved to.

    Returns:
        float: average of the per-batch loss values over the epoch.

    Raises:
        ValueError: if the DataLoader yields no batches (e.g. the dataset has
            fewer than ``batch_size`` samples, since ``drop_last=True``).
    """
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                                  num_workers=workers)
    model.train()
    loss_sum = 0.0
    print_freq = 20
    num_batches = 0
    for num_batches, (data, _filenames) in enumerate(train_dataloader, start=1):
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.float().to(device)

        # Flatten the model output to 1-D (presumably to match the shape of
        # targets — confirm against the model's output shape).
        outputs = torch.reshape(model(imgs), [-1])
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss
        if num_batches % print_freq == 0:
            print("---loss: {}---".format(batch_loss))
            log_info("---loss: {}---".format(batch_loss))

    if num_batches == 0:
        raise ValueError(
            "training DataLoader produced no batches (dataset size {} < batch_size {} "
            "with drop_last=True?)".format(len(train_dataset), batch_size)
        )

    avg_loss = loss_sum / num_batches
    print("----------本轮的平均loss为：{}  ----------\n".format(avg_loss))
    log_info("----------本轮的平均loss为：{}  ----------\n".format(avg_loss))

    lr_scheduler.step()
    return avg_loss

import torch
import math
from torch.utils.data import DataLoader

def val(val_dataset, model, device):
    """Evaluate *model* on *val_dataset* and return the mean absolute count error.

    Each DataLoader batch (batch size 1) unpacks as ``((imgs, targets), filename)``.
    The predicted count is the sum of the model output; the ground-truth
    count is the sum of ``targets``. Also prints/logs the model's total and
    trainable parameter counts, and prints/logs every 15th sample's result.

    Args:
        val_dataset: dataset yielding ``((img, target), filename)`` samples.
        model: torch module to evaluate (switched to eval mode).
        device: device the inputs are moved to.

    Returns:
        float: MAE over all samples. The root-mean-square error is printed and
        logged (labelled "MSE") but not returned.

    Raises:
        ValueError: if the dataset yields no samples.
    """
    val_dataloader = DataLoader(val_dataset, batch_size=1)
    model.eval()

    # parameters() de-duplicates shared tensors, exactly like named_parameters().
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    log_info(f"Total parameters: {total_params}")
    log_info(f"Trainable parameters: {trainable_params}")

    abs_err_sum = 0.0
    sq_err_sum = 0.0
    i = 0
    for i, (data, filename) in enumerate(val_dataloader, start=1):
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            count = torch.sum(model(imgs)).item()

        gt_count = torch.sum(targets).item()
        err = abs(gt_count - count)
        abs_err_sum += err
        sq_err_sum += err * err

        if i % 15 == 0:
            # Same 3-field message for console and log (the console format
            # previously had only 2 placeholders, so values were mislabelled).
            sample_msg = "图像名字：{}    真实数量：{}     \t 预测数量：{}".format(filename, gt_count, count)
            print(sample_msg)
            log_info(sample_msg)

    if i == 0:
        raise ValueError("validation dataset is empty; cannot compute MAE")

    mae = abs_err_sum / i
    # Reported as "MSE" but this is the root-mean-square error.
    mse = math.sqrt(sq_err_sum / i)
    print("此次测试结果为：MAE：{}  \t MSE：{}".format(mae, mse))
    log_info("此次测试结果为：MAE：{}  \t MSE：{}".format(mae, mse))

    return mae


def test(test_dataset, model, device, best):
    """Evaluate *model* on the test set, printing every sample's counts.

    Each DataLoader batch (batch size 1) unpacks as ``((imgs, targets), filename)``,
    the same layout train() and val() consume. The predicted count is the sum
    of the model output; the ground-truth count is the sum of ``targets``.

    Args:
        test_dataset: dataset yielding ``((img, target), filename)`` samples.
        model: torch module to evaluate (switched to eval mode).
        device: device the inputs are moved to (also used as ``map_location``).
        best: best validation MAE tracked by the caller. When ``best < 10`` the
            weights stored at ``save_file`` are loaded first; otherwise the
            model is evaluated with its current weights.

    Returns:
        tuple[float, float]: ``(mae, rmse)``; the RMSE is printed as "MSE".

    Raises:
        ValueError: if the dataset yields no samples.
    """
    if best < 10:
        # map_location lets a checkpoint saved on one device load on another.
        model.load_state_dict(torch.load(save_file, map_location=device))
    test_dataloader = DataLoader(test_dataset, batch_size=1)
    model.eval()

    abs_err_sum = 0.0
    sq_err_sum = 0.0
    i = 0
    # Unpack (data, filename) per batch; iterating without this unpacking made
    # `imgs` a list and left `filename` undefined.
    for i, (data, filename) in enumerate(test_dataloader, start=1):
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            count = torch.sum(model(imgs)).item()

        gt_count = torch.sum(targets).item()
        err = abs(gt_count - count)
        abs_err_sum += err
        sq_err_sum += err * err

        print("图像：{}               真实数量：{}     \t 预测数量：{}".format(filename, gt_count, count))

    if i == 0:
        raise ValueError("test dataset is empty; cannot compute MAE")

    mae = abs_err_sum / i
    # Reported as "MSE" but this is the root-mean-square error.
    mse = math.sqrt(sq_err_sum / i)
    summary = "此次测试结果为：MAE：{}  \t MSE：{}".format(mae, mse)
    print(summary)
    # Also persist the final test result to the log, consistent with val().
    log_info(summary)
    return mae, mse


if __name__ == "__main__":
    # Script entry point: training, periodic validation, then the final test.
    main()







