pytorch transformers ....
发布于 2022年 05月 19日 13:14
transformers的预训练模型下载到本地特定位置,默认是在~/.cache/huggingface/transformers
model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir="...")
想知道transformers的模型都是什么结构的,比如bert模型:
transformers/models/bert/__init__.py
这里可以看到导入了
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMHeadModel,
BertModel,
BertPreTrainedModel,
load_tf_weights_in_bert,
)
然后点进去就可以看了,可以看他们的forward函数等