@@ -129,18 +129,19 @@ from PIL import Image
129129import json
130130from torchvision import models, transforms
131131
132- def load_labels (label_path , client ):
133- local_path = client.file(label_path).getFile().name
134- with open (local_path) as f:
132+
133+ client = Algorithmia.client()
134+
135+ def load_labels (label_path ):
136+ with open (label_path) as f:
135137 labels = json.load(f)
136138 labels = [labels[str (k)][1 ] for k in range (len (labels))]
137139 return labels
138140
139141
140- def load_model (model_paths , client ):
142+ def load_model (model_path ):
141143 model = models.squeezenet1_1()
142- local_file = client.file(model_paths[" filepath" ]).getFile().name
143- weights = torch.load(local_file)
144+ weights = torch.load(model_path)
144145 model.load_state_dict(weights)
145146 return model.float().eval()
146147
@@ -176,26 +177,25 @@ def infer_image(image_url, n, globals):
176177
177178def load (manifest ):
178179
179- globals = {}
180- client = Algorithmia.client()
181- globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
182- globals [" model" ] = load_model(manifest[" squeezenet" ], client)
183- globals [" labels" ] = load_labels(manifest[" label_file" ], client)
184- return globals
180+ state = {}
181+ state[" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
182+ state[" model" ] = load_model(manifest.get_model(" squeezenet" ))
183+ state[" labels" ] = load_labels(manifest.get_model(" labels" ))
184+ return state
185185
186186
187- def apply (input , globals ):
187+ def apply (input , state ):
188188 if isinstance (input , dict ):
189189 if " n" in input :
190190 n = input [" n" ]
191191 else :
192192 n = 3
193193 if " data" in input :
194194 if isinstance (input [" data" ], str ):
195- output = infer_image(input [" data" ], n, globals )
195+ output = infer_image(input [" data" ], n, state )
196196 elif isinstance (input [" data" ], list ):
197197 for row in input [" data" ]:
198- row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
198+ row[" predictions" ] = infer_image(row[" image_url" ], n, state )
199199 output = input [" data" ]
200200 else :
201201 raise Exception (" \" data\" must be a image url or a list of image urls (with labels)" )
@@ -206,7 +206,7 @@ def apply(input, globals):
206206 raise Exception (" input must be a json object" )
207207
208208
209- algorithm = ADK(apply_func = apply, load_func = load)
209+ algorithm = ADK(apply_func = apply, load_func = load, client = client )
210210algorithm.init({" data" : " https://i.imgur.com/bXdORXl.jpeg" })
211211
212212```
0 commit comments