pytorch 模型搭建步骤

具体的步骤

  1. 首先对数据进行处理(这两个应该放在util中)
  2. 构建dataset
    --Init-- :初始化参数
    --Len--:数据长度
    --Getitem--:根据索引取数据(索引和len有关)返回对应的数据
  1. 构建dataloader
    生成的是一个迭代器,用于小批量运算
    自定义collate_fn函数可以设置dataloader返回的数据
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True,collate_fn=collate_fn)
    这里可以将训练集和测试机合成的整个数据集放入 设置sampler来划分训练集和测试集,设置了sampler 那么shuffle自动失效
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
training_generator = data.DataLoader(train_dataset, **params,
                                               sampler=train_sampler)
val_generator = data.DataLoader(train_dataset, **params,
                                                    sampler=valid_sampler)

设置种子

def seed_torch(seed=2):
    seed = int(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic =True

运行部分:

这里小数据可以使用交叉验证的方式

分一部分训练模式:
训练

print('Plan to train {} epoches \n'.format(epochs))

    for epoch in range(epochs):

        # mini-batch for training
        train_loss_list = []
        train_acc_list = []
        model.train()
        for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
            # forward
            batch_inputs, batch_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device_id)
            blocks = [block.to(device_id) for block in blocks]
            # metric and loss
            train_batch_logits = model(blocks, batch_inputs)
            batch_labels=th.tensor(batch_labels,dtype=th.long)
            train_loss = loss_fn(train_batch_logits, batch_labels)

            # backward
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            train_loss_list.append(train_loss.cpu().detach().numpy())
            tr_batch_pred = th.sum(th.argmax(train_batch_logits, dim=1) == batch_labels) / th.tensor(
                batch_labels.shape[0])

            if step % 10 == 0:
                print('In epoch:{:03d}|batch:{:04d}, train_loss:{:4f}, train_acc:{:.4f}'.format(epoch,
                                                                                                step,
                                                                                                np.mean(
                                                                                                    train_loss_list),
                                                                                                tr_batch_pred.detach()))

        # mini-batch for validation
        val_loss_list = []
        val_acc_list = []
        model.eval()

测试
model.eval() # 测试模式,不会更新参数 不使用dropout等
with torch.no_grad(): # 不计算梯度加快计算

交叉验证版本:

# 根据id选择数据
class get_Dataset():
    def __init__(self,idx,all_data): # 设置初始信息
        self.seg_data=all_data[idx]
    def __len__(self): # 返回长度
        return len(self.seg_data)
    def __getitem__(self, item): # 根据item返回数据
        return self.seg_data[item]



skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_seed)
for fold, (train_index, valid_index) in enumerate(skf.split(all_label, all_label)):
  training_data=get_dataset(train_index,train_data)
  train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
  validing_data=get_dataset(valid_index,train_data)
  valid_dataloader = DataLoader(validing_data, batch_size=64, shuffle=True)

在GPU下运行
网络模型
数据(只有数据需要赋值)
损失函数

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device )  #这就放到GPU上运行了
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,711评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,932评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,770评论 0 330
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,799评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,697评论 5 359
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,069评论 1 276
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,535评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,200评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,353评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,290评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,331评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,020评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,610评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,694评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,927评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,330评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,904评论 2 341

推荐阅读更多精彩内容

  • 1. TENSORS Tensors是一种特殊的数据结构,非常类似于数组和矩阵。在PyTorch中,我们使用Ten...
    龙小治阅读 318评论 0 1
  • 顺着文档理解代码,加备注 import os, sys, glob, shutil, json os.enviro...
    王晓丹阅读 348评论 0 0
  • transformers是huggingface提供的预训练模型库,可以轻松调用API来得到你的词向量。trans...
    晓柒NLP与药物设计阅读 7,029评论 0 10
  • PyTorch,tensorflow.通过学习可以快速掌握的两个机器学习库对应的内容.但是这两个库非常强大,下次换...
    snipon阅读 248评论 0 0
  • 本章通过介绍构建神经网络所涉及的基本思想,如激活函数,损失函数,优化器和监督训练设置,为后面的章节打基础。我们从一...
    readilen阅读 2,691评论 2 12