pytorch transformers ....

发布于 2022年 05月 19日 13:14

transformers的预训练模型下载到本地特定位置,默认是在~/.cache/huggingface/transformers

model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir="...")

想知道transformers的模型都是什么结构的,比如bert模型:

transformers/models/bert/__init__.py

这里可以看到导入了

from .modeling_bert import (
            BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
            BertLayer,
            BertLMHeadModel,
            BertModel,
            BertPreTrainedModel,
            load_tf_weights_in_bert,
        )

然后点进去就可以看了,可以看他们的forward函数等

推荐文章