Skip to content

工具类 API 参考

本文档详细介绍 Torch-RecHub 中各个工具类的 API 接口和参数说明。

数据处理工具 (data.py)

数据集类

TorchDataset

  • 简介:PyTorch数据集的基础实现,用于处理特征和标签数据。
  • 参数
    • x (dict): 特征字典,键为特征名,值为特征数据
    • y (array): 标签数据

PredictDataset

  • 简介:用于预测阶段的数据集类,只包含特征数据。
  • 参数
    • x (dict): 特征字典,键为特征名,值为特征数据

MatchDataGenerator

  • 简介:召回任务的数据生成器,用于生成训练和测试数据加载器。
  • 主要方法
    • generate_dataloader(x_test_user, x_all_item, batch_size, num_workers=8): 生成训练、测试和物品数据加载器
    • 参数
      • x_test_user (dict): 测试用户特征
      • x_all_item (dict): 所有物品特征
      • batch_size (int): 批次大小
      • num_workers (int): 数据加载的工作进程数

DataGenerator

  • 简介:通用数据生成器,支持数据集的划分和加载。
  • 主要方法
    • generate_dataloader(x_val=None, y_val=None, x_test=None, y_test=None, split_ratio=None, batch_size=16, num_workers=0): 生成训练、验证和测试数据加载器
    • 参数
      • x_val, y_val: 验证集特征和标签
      • x_test, y_test: 测试集特征和标签
      • split_ratio (list): 训练集、验证集、测试集的划分比例
      • batch_size (int): 批次大小
      • num_workers (int): 数据加载的工作进程数

工具函数

get_auto_embedding_dim

  • 简介:根据类别数自动计算嵌入向量维度。
  • 参数
    • num_classes (int): 类别数量
  • 返回
    • int: 嵌入向量维度,计算公式:[6 * (num_classes)^(1/4)]

get_loss_func

  • 简介:获取损失函数。
  • 参数
    • task_type (str): 任务类型,"classification"或"regression"
  • 返回
    • torch.nn.Module: 对应的损失函数

get_metric_func

  • 简介:获取评估指标函数。
  • 参数
    • task_type (str): 任务类型,"classification"或"regression"
  • 返回
    • function: 对应的评估指标函数

generate_seq_feature

  • 简介:生成序列特征和负样本。
  • 参数
    • data (pd.DataFrame): 原始数据
    • user_col (str): 用户ID列名
    • item_col (str): 物品ID列名
    • time_col (str): 时间戳列名
    • item_attribute_cols (list): 需要生成序列特征的物品属性列
    • min_item (int): 用户最少交互物品数
    • shuffle (bool): 是否打乱数据
    • max_len (int): 序列最大长度

召回工具 (match.py)

数据处理函数

gen_model_input

  • 简介:合并用户和物品特征,处理序列特征。
  • 参数
    • df (pd.DataFrame): 带有历史序列特征的数据
    • user_profile (pd.DataFrame): 用户特征数据
    • user_col (str): 用户列名
    • item_profile (pd.DataFrame): 物品特征数据
    • item_col (str): 物品列名
    • seq_max_len (int): 序列最大长度
    • padding (str): 填充方式,'pre'或'post'
    • truncating (str): 截断方式,'pre'或'post'

negative_sample

  • 简介:召回模型的负采样方法。
  • 参数
    • items_cnt_order (dict): 物品计数字典,按计数降序排序
    • ratio (int): 负样本比例
    • method_id (int): 采样方法ID
      • 0: 随机采样
      • 1: Word2Vec式流行度采样
      • 2: 对数流行度采样
      • 3: 腾讯RALM采样

向量检索类

Annoy

  • 简介:基于Annoy的向量召回工具。
  • 参数
    • metric (str): 距离度量方式
    • n_trees (int): 树的数量
    • search_k (int): 搜索参数
  • 主要方法
    • fit(X): 构建索引
    • query(v, n): 查询最近邻

Milvus

  • 简介:基于Milvus的向量召回工具。
  • 参数
    • dim (int): 向量维度
    • host (str): Milvus服务器地址
    • port (str): Milvus服务器端口
  • 主要方法
    • fit(X): 构建索引
    • query(v, n): 查询最近邻

多任务学习工具 (mtl.py)

工具函数

shared_task_layers

  • 简介:获取多任务模型中的共享层和任务特定层参数。
  • 参数
    • model (torch.nn.Module): 多任务模型,支持MMOE、SharedBottom、PLE、AITM
  • 返回
    • list: 共享层参数列表
    • list: 任务特定层参数列表

优化器类

MetaBalance

  • 简介:MetaBalance优化器,用于平衡多任务学习中各任务的梯度。
  • 参数
    • parameters (list): 模型参数
    • relax_factor (float): 梯度缩放的松弛因子,默认0.7
    • beta (float): 移动平均系数,默认0.9
  • 主要方法
    • step(losses): 执行优化步骤,更新参数

梯度处理函数

gradnorm

  • 简介:实现GradNorm算法,用于动态调整多任务学习中的任务权重。
  • 参数
    • loss_list (list): 各任务的损失列表
    • loss_weight (list): 任务权重列表
    • share_layer (torch.nn.Parameter): 共享层参数
    • initial_task_loss (list): 初始任务损失列表
    • alpha (float): GradNorm算法的超参数