From ce58f9a0a0af71907d97c842a12a5ecb6ca57276 Mon Sep 17 00:00:00 2001 From: f0ti Date: Thu, 19 Nov 2020 14:21:13 +0100 Subject: [PATCH] allow_pickle, for outdated numpy --- src/get_nets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/get_nets.py b/src/get_nets.py index 67771ff..b85b210 100644 --- a/src/get_nets.py +++ b/src/get_nets.py @@ -52,7 +52,7 @@ def __init__(self): self.conv4_1 = nn.Conv2d(32, 2, 1, 1) self.conv4_2 = nn.Conv2d(32, 4, 1, 1) - weights = np.load('src/weights/pnet.npy')[()] + weights = np.load('src/weights/pnet.npy', allow_pickle=True)[()] for n, p in self.named_parameters(): p.data = torch.FloatTensor(weights[n]) @@ -97,7 +97,7 @@ def __init__(self): self.conv5_1 = nn.Linear(128, 2) self.conv5_2 = nn.Linear(128, 4) - weights = np.load('src/weights/rnet.npy')[()] + weights = np.load('src/weights/rnet.npy', allow_pickle=True)[()] for n, p in self.named_parameters(): p.data = torch.FloatTensor(weights[n]) @@ -148,7 +148,7 @@ def __init__(self): self.conv6_2 = nn.Linear(256, 4) self.conv6_3 = nn.Linear(256, 10) - weights = np.load('src/weights/onet.npy')[()] + weights = np.load('src/weights/onet.npy', allow_pickle=True)[()] for n, p in self.named_parameters(): p.data = torch.FloatTensor(weights[n])