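I'm trying to use create_feature_extractor, imported with from torchvision.models.feature_extraction import create_feature_extractor.
The model I'm trying to use it on comes from vit_pytorch (link: https://github.com/lucidrains/vit-pytorch). The problem appears when I create a model from this library: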
from vit_pytorch import ViT
from torchvision.models.feature_extraction import create_feature_extractor

# Small ViT for 28x28 single-channel images
model = ViT(image_size=28,
            patch_size=7,
            num_classes=10,
            dim=16,
            depth=6,
            heads=16,
            mlp_dim=256,
            dropout=0.1,
            emb_dropout=0.1,
            channels=1)

# return_nodes expects a list (or dict) of node names, not a bare string
random_layer_name = 'transformer.layers.1.1.fn.net.4'
feature_extractor = create_feature_extractor(model,
                                             return_nodes=[random_layer_name])
When I call create_feature_extractor() on this model, I always get this error:
RuntimeError Traceback (most recent call last)
Cell In[17], line 2
1 # torch.fx.wrap('len')
----> 2 feature_extractor = create_feature_extractor(model,
3 return_nodes=['transformer.layers.1.1.fn.net.4'])
File ~/Mokslas/AI/venv/lib/python3.10/site-packages/torchvision/models/feature_extraction.py:485, in create_feature_extractor(model, return_nodes, train_return_nodes, eval_return_nodes, tracer_kwargs, suppress_diff_warning)
483 # Instantiate our NodePathTracer and use that to trace the model
484 tracer = NodePathTracer(**tracer_kwargs)
--> 485 graph = tracer.trace(model)
487 name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
488 graph_module = fx.GraphModule(tracer.root, graph, name)
File ~/Mokslas/AI/venv/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:756, in Tracer.trace(self, root, concrete_args)
749 for module in self._autowrap_search:
750 _autowrap_check(
751 patcher, module.__dict__, self._autowrap_function_ids
752 )
753 self.create_node(
754 "output",
755 "output",
--> 756 (self.create_arg(fn(*args)),),
757 {},
758 type_expr=fn.__annotations__.get("return", None),
759 )
761 self.submodule_paths = None
762 finally:
File ~/Mokslas/AI/venv/lib/python3.10/site-packages/vit_pytorch/vit.py:115, in ViT.forward(self, img)
114 def forward(self, img):
--> 115 x = self.to_patch_embedding(img)
116 b, n, _ = x.shape
118 cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
File ~/Mokslas/AI/venv/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:734, in Tracer.trace.<locals>.module_call_wrapper(mod, *args, **kwargs)
727 return _orig_module_call(mod, *args, **kwargs)
729 _autowrap_check(
730 patcher,
731 getattr(getattr(mod, "forward", mod), "__globals__", {}),
732 self._autowrap_function_ids,
733 )
--> 734 return self.call_module(mod, forward, args, kwargs)
File ~/Mokslas/AI/venv/lib/python3.10/site-packages/torchvision/models/feature_extraction.py:83, in NodePathTracer.call_module(self, m, forward, args, kwargs)
...
--> 396 raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
397 "this call to be recorded, please call torch.fx.wrap('len') at "
398 "module scope")
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
No matter which model I pick from that library, or which layer or layers I choose as output, I always get the same error.
I tried adding torch.fx.wrap('len'), but the same problem persists. I know I could work around it with hooks, but is there a way to fix this so that I can still use create_feature_extractor()?
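
For reference, the hook-based fallback I have in mind looks roughly like the sketch below. It is only an illustration: it uses a small stand-in nn.Sequential model instead of the ViT above, and the names save_output, features and the layer name "1" are placeholders I made up. With vit_pytorch the same pattern would presumably be applied to a name such as 'transformer.layers.1.1.fn.net.4', assuming that submodule exists in the installed version.

```python
import torch
from torch import nn

# Stand-in model (hypothetical); not the vit_pytorch ViT from the question.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.GELU(),
    nn.Linear(16, 4),
)

features = {}  # filled in by the hook on every forward pass


def save_output(name):
    """Build a forward hook that stores the module's output under `name`."""
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook


# Look the layer up by its dotted name, as listed by model.named_modules().
layer_name = "1"
layer = model.get_submodule(layer_name)
handle = layer.register_forward_hook(save_output(layer_name))

model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 8))

# Remove the hook once the features have been captured.
handle.remove()
print(features[layer_name].shape)
```

This avoids symbolic tracing entirely, so the 'len' error cannot occur, but it gives up the single traced GraphModule that create_feature_extractor() would return.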