传知代码-基于图神经网络的知识追踪方法(论文复现)

news/2024/9/28 23:06:58 标签: 神经网络, 人工智能, 深度学习, 机器学习

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

1.论文概述

论文链接提出了一种基于图神经网络的知识追踪方法,称为基于图的知识追踪(GKT)。将知识结构构建为图,其中节点对应于概念,边对应于它们之间的关系,将知识追踪任务构建为图神经网络中的时间序列节点级分类问题。在两个开放数据集上的实证验证表明,方法可以更好地预测学生的表现,并且该模型比先前的方法具有更可解释的预测。
贡献如下:
(1)展示了知识追踪可以重新构想为图神经网络的应用。
(2)为了实现需要输入模型的图结构,在许多情况下并不明确的情况下,我们提出了各种方法,并使用实证验证进行了比较。
(3)证明了所提出的方法比先前的方法更准确和可解释的预测。

2.论文方法

下面是本文提出GKT的体系结构。
在这里插入图片描述

2.1 聚合

模型聚合了回答的概念及其相邻概念的隐藏状态和嵌入。这种聚合使用隐藏状态、表示正确和错误答案的输入向量 xt​,以及概念及其回答的嵌入矩阵Ex 和Ec 进行,
在这里插入图片描述

2.2 更新

接下来,模型根据聚集的特征和知识图结构更新隐藏状态。这一步骤确保模型融合了当前概念及其在知识图中的相邻节点的信息。
在这里插入图片描述

2.3 预测

最后,模型输出学生在下一时间步正确回答每个概念的预测概率
在这里插入图片描述

3. 实验

3.1 数据集

使用了学生数学练习日志的两个开放数据集:ASSISTments 2009-2010“skill-builder”由在线教育服务 ASSISTments1(以下称为“ASSISTments”)提供和 Bridge to Algebra 2006-2007 [19] 用于KDDCup 教育数据挖掘挑战赛(以下简称“KDDCup”)。在这两个数据集中,每个练习都分配了人类预定义的知识概念标签。
使用特定条件预处理每个数据集。对于ASSISTments,将同时回答的日志合二为一,随后提取与命名概念标签相关联的日志,最后提取与至少10次回答的概念标签相关联的日志。对于 KDDCup,将问题和步骤的组合视为一个答案,然后提取与命名且非哑元的概念标签相关联的日志,最后提取至少 10 次回答的概念标签相关联的日志。由于频繁同时出现的标签,将同时的回答日志组合成一组可以防止不公平的高预测性能。排除未命名或虚拟的概念标签可以消除噪音。用回答每个概念标签的次数对日志进行阈值处理,以确保有足够数量的日志来消除噪音。在使用上述条件对数据集进行预处理后,为 ASSISTments 数据集获得了 62, 955 个日志,由 1, 000 名学生和 101 项技能组成,并为 KDDCup 数据集获得了 98, 200 条日志,由 1, 000 名学生和 211 项技能组成。
在这里插入图片描述

3.2 实验步骤

Step1:处理数据集

在这里插入图片描述

Step2:进行训练

在这里插入图片描述

3.3 实验结果

在这里插入图片描述

4.核心代码

class GKT(KTM):
    def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_params=None):
        super(GKT, self).__init__()
        self.gkt_model = GKTNet(
            ku_num,
            graph,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        # self.gkt_model = GKTNet(ku_num, graph, hidden_num)
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        loss_function = SLMLoss(**self.loss_params)
        trainer = torch.optim.Adam(self.gkt_model.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.gkt_model(question, data, data_mask)

                loss = loss_function(predicted_response, pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
            print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy = self.eval(test_data)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        self.gkt_model.eval()
        y_true = []
        y_pred = []

        for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            question: torch.Tensor = question.to(device)
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.gkt_model(question, data, data_mask)
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            for i, length in enumerate(label_mask.numpy().tolist()):
                length = int(length)
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.gkt_model.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        torch.save(self.gkt_model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.gkt_model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)

源码下载


http://www.niftyadmin.cn/n/5682028.html

相关文章

用通义灵码如何快速合理解决遗留代码问题?

本文首先介绍了遗留代码的概念,并对遗留代码进行了分类。针对不同类型的遗留代码,提供了相应的处理策略。此外,本文重点介绍了通义灵码在维护遗留代码过程中能提供哪些支持。 什么是遗留代码 与过时技术相关的代码: 与不再受支持的…

【linux进程】深度理解进程--什么是进程什么是pcb进程创建

目录 前言一,对PCB的理解二,CPU对进程列表的处理三,进程标识符:pid1. 查看系统进程1: ps axj2. 查看系统进程2: /proc 四,系统调用函数:getpid五,父进程和子进程的概念六,创建子进程--fork函数的使用1. 创建…

胤娲科技:AI界的超级充电宝——忆阻器如何让LLM告别电量焦虑

当AI遇上“记忆橡皮擦”,电量不再是问题! 嘿,朋友们,你们是否曾经因为手机电量不足而焦虑得像个无头苍蝇?想象一下,如果这种“电量焦虑”也蔓延到了AI界, 特别是那些聪明绝顶但“耗电如喝水”的…

map的键排序方法

1.对map中的key进行正序排序 Map<Integer, String> map Maps.newHashMap();// 原始map LinkedHashMap<Integer, String> sortedMap map.entrySet().stream().sorted(Map.Entry.comparingByKey()) // .collect(Collectors.toMap(Map.Entry::getKey…

【Python】Pythonic Data Structures and Algorithms:深入浅出数据结构与算法的 Python 实现

Pythonic Data Structures and Algorithms 是一个开源项目&#xff0c;汇集了各种经典数据结构和算法的 Python 实现。该项目旨在为开发者提供丰富的学习资源&#xff0c;帮助他们通过 Python 代码理解和掌握数据结构与算法的核心原理和应用。项目中的算法涵盖了排序、搜索、图…

Linux基础知识 + 常用命令

Linux基础 与Windows不同 1.Linux严格区分大小写 2.Linux中所有内容都已文件形式保存&#xff0c;包括硬件 3.Linux不靠拓展名区分文件类型 4.Windows下的程序不能直接在Linux中安装和运行 Linux管理 常用命令 ls 【选项】【文件或目录】 -a 全部 -l 详细 -h 人性化…

Java_集合_单列集合Collection

第一章.Collection接口 Collection<E> 集合名 new 实现类对象<E>() 常用方法: boolean add(E e) : 将给定的元素添加到当前集合中(我们一般调add时,不用boolean接收,因为add一定会成功) boolean addAll(Collection<? extends E> c) :将另一个集合元素添…

Linux系统中的重定向

目录 一、回顾重定向命令 1.输出重定向 > 2.追加重定向 >> 3.输入重定向 < 二、重定向原理 三、dup2函数 一、回顾重定向命令 1.输出重定向 > echo xxx > filename&#xff1a;将数据写入到文件中 文件不存在则创建文件再写入&#xff1b;文件存在则…