diff --git a/src/get_nets.py b/src/get_nets.py index 67771ff..b85b210 100644 --- a/src/get_nets.py +++ b/src/get_nets.py @@ -52,7 +52,7 @@ def __init__(self): self.conv4_1 = nn.Conv2d(32, 2, 1, 1) self.conv4_2 = nn.Conv2d(32, 4, 1, 1) - weights = np.load('src/weights/pnet.npy')[()] + weights = np.load('src/weights/pnet.npy', allow_pickle=True)[()] for n, p in self.named_parameters(): p.data = torch.FloatTensor(weights[n]) @@ -97,7 +97,7 @@ def __init__(self): self.conv5_1 = nn.Linear(128, 2) self.conv5_2 = nn.Linear(128, 4) - weights = np.load('src/weights/rnet.npy')[()] + weights = np.load('src/weights/rnet.npy', allow_pickle=True)[()] for n, p in self.named_parameters(): p.data = torch.FloatTensor(weights[n]) @@ -148,7 +148,7 @@ def __init__(self): self.conv6_2 = nn.Linear(256, 4) self.conv6_3 = nn.Linear(256, 10) - weights = np.load('src/weights/onet.npy')[()] + weights = np.load('src/weights/onet.npy', allow_pickle=True)[()] for n, p in self.named_parameters(): p.data = torch.FloatTensor(weights[n])