Skip to content

Commit 32576e4

Browse files
committed
created functional model_manifest workflow
1 parent 17815d1 commit 32576e4

File tree

5 files changed

+138
-41
lines changed

5 files changed

+138
-41
lines changed

adk/ADK.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,24 @@
33
import os
44
import sys
55
from adk.io import create_exception, format_data, format_response
6+
from adk.manifest import ManifestData
67

78

89
class ADK(object):
9-
def __init__(self, apply_func, load_func=None):
10+
def __init__(self, apply_func, load_func=None, client=None, manifest_path="model_manifest.json.lock"):
1011
"""
1112
Creates the adk object
1213
:param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
1314
:param load_func: An optional supplier function used if load time events are required, has an arity of 0.
15+
:param client: An optional, user-supplied Algorithmia Client instance; it is used to interact with the model manifest file when one is defined.
1416
"""
1517
self.FIFO_PATH = "/tmp/algoout"
16-
self.manifest_path = "data_manifest.json"
1718
apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func)
1819
if load_func:
1920
load_args, _, _, _, _, _, _ = inspect.getfullargspec(load_func)
20-
if len(load_args) > 0:
21-
raise Exception("load function must not have parameters")
21+
if len(load_args) > 2:
22+
raise Exception("load function may either have no parameters, or one parameter providing the manifest "
23+
"state.")
2224
self.load_func = load_func
2325
else:
2426
self.load_func = None
@@ -28,10 +30,14 @@ def __init__(self, apply_func, load_func=None):
2830
self.is_local = not os.path.exists(self.FIFO_PATH)
2931
self.load_result = None
3032
self.loading_exception = None
33+
self.manifest = ManifestData(client, manifest_path)
3134

3235
def load(self):
3336
try:
34-
if self.load_func:
37+
if self.load_func and self.manifest.available():
38+
self.manifest.initialize()
39+
self.load_result = self.load_func(self.manifest)
40+
elif self.load_func:
3541
self.load_result = self.load_func()
3642
except Exception as e:
3743
self.loading_exception = e

adk/manifest.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,73 @@
33
import hashlib
44

55

6-
def process_manifest(self):
7-
if os.path.exists(self.manifest_path):
8-
with open(self.manifest_path) as f:
9-
manifest_data = json.load(f)
10-
manifest = manifest_data
6+
class ManifestData(object):
    """Track the model files described by a model manifest.

    The manifest is read once at construction time through ``get_manifest``.
    Each entry that gets resolved is downloaded through the supplied client,
    hashed with ``md5`` and recorded in ``self.models`` as
    ``{"md5_checksum": <observed hash>, "model_path": <local file path>}``.
    """

    def __init__(self, client, model_manifest_path):
        """
        :param client: Client object used to download model files; may be None,
            in which case no model can be resolved.
        :param model_manifest_path: Path of the manifest file handed to ``get_manifest``.
        """
        self.manifest_lock_path = model_manifest_path
        self.manifest_data = get_manifest(self.manifest_lock_path)
        self.client = client
        # name -> {"md5_checksum": ..., "model_path": ...} for every resolved model.
        self.models = {}

    def available(self):
        """Return True when manifest data was loaded and is non-empty, else False."""
        return bool(self.manifest_data)

    def initialize(self):
        """Download and verify every entry listed under ``required_models``.

        :raises Exception: if no client was supplied, if two entries share a
            name, or if a checksum mismatch is found on an entry that sets
            ``fail_on_tamper``.
        """
        if self.client is None:
            raise Exception("Client was not defined, please define a Client when using Model Manifests.")
        for required_file in self.manifest_data['required_models']:
            name = required_file['name']
            if name in self.models:
                raise Exception("Duplicate 'name' detected. \n"
                                + name + " was found to be used by more than one data file, please rename.")
            self.models[name] = self._fetch_and_verify(name, required_file)

    def get_model(self, model_name):
        """Return the local file path recorded for ``model_name``.

        A model that has not been resolved yet is looked up among the
        manifest's ``optional_models`` and downloaded on demand.

        :raises Exception: if the name is not known to the manifest.
        """
        if model_name not in self.models:
            # find_optional_model raises when the manifest does not list the name,
            # so reaching the return below implies the entry exists.
            self.find_optional_model(model_name)
        return self.models[model_name]['model_path']

    def find_optional_model(self, model_name):
        """Download, verify and record the optional model called ``model_name``.

        :raises Exception: if no ``optional_models`` entry carries that name,
            if no client was supplied, or on a fatal checksum mismatch.
        """
        # Tolerate a missing manifest or a manifest without an optional section.
        optional_entries = (self.manifest_data or {}).get('optional_models', [])
        found_models = [optional for optional in optional_entries if
                        optional['name'] == model_name]
        if not found_models:
            raise Exception("model with name '" + model_name + "' not found in model manifest.")
        # Register only after a successful download so a failure never leaves
        # a half-initialised entry behind.
        self.models[model_name] = self._fetch_and_verify(model_name, found_models[0])

    def _fetch_and_verify(self, name, model_info):
        """Download one manifest entry and compare its md5 with the expected checksum.

        :param name: Model name, used in error messages only.
        :param model_info: Manifest entry; ``md5_checksum`` and ``data_api_path``
            are required, ``fail_on_tamper`` is optional and treated as False
            when absent.
        :return: ``{"md5_checksum": <observed hash>, "model_path": <local path>}``.
        :raises Exception: if no client was supplied, or if the hashes differ
            and the entry sets ``fail_on_tamper``.
        """
        if self.client is None:
            raise Exception("Client was not defined, please define a Client when using Model Manifests.")
        expected_hash = model_info['md5_checksum']
        # NOTE(review): assumes the downloaded file stays on disk after the handle
        # is closed; callers open the recorded path later, so it has to persist.
        with self.client.file(model_info['data_api_path']).getFile() as f:
            local_data_path = f.name
        real_hash = md5(local_data_path)
        # A mismatch is fatal only when the entry opts in; otherwise the observed
        # hash is recorded so callers can inspect it.
        if real_hash != expected_hash and model_info.get('fail_on_tamper', False):
            raise Exception("Model File Mismatch for " + name +
                            "\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
        return {"md5_checksum": real_hash, "model_path": local_data_path}
64+
65+
66+
def get_manifest(manifest_path):
    """Parse the JSON manifest stored at ``manifest_path``.

    :param manifest_path: Filesystem path of the manifest file.
    :return: The decoded JSON content, or None when the path does not exist.
    """
    if not os.path.exists(manifest_path):
        return None
    with open(manifest_path) as manifest_file:
        return json.load(manifest_file)
1973

2074

2175
def md5(fname):
Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
{
22
"algorithm_name": "pytorch_image_classification",
3-
"model_versioning": "strict",
4-
"required_files" : [
5-
{ "name": "labels",
3+
"required_models": [
4+
{
5+
"name": "model_squeezenet",
6+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
7+
"fail_on_tamper": true,
8+
"metadata": {
9+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
10+
}
11+
},
12+
{
13+
"name": "labels",
614
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
7-
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
8-
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
15+
"fail_on_tamper": true,
16+
"metadata": {
17+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
18+
}
919
}
1020
],
11-
"optional_files": [
21+
"optional_models": [
1222
{
13-
"name": "optional_data",
14-
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
15-
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
16-
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
23+
"name": "vgg",
24+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/vgg16-397923af.pth",
25+
"fail_on_tamper": false,
26+
"metadata": {
27+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
28+
}
1729
}
1830
]
1931
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"algorithm_name": "pytorch_image_classification",
3+
"model_tamper_behavior": "strict",
4+
"lock_version": "f7aec94cde3ca6274c6556ebe58ad3c6",
5+
"timestamp": "1632753496",
6+
"required_models" : [
7+
{ "name": "labels",
8+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
9+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
10+
"metadata": {
11+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
12+
}
13+
}
14+
],
15+
"optional_models": [
16+
{
17+
"name": "optional_data",
18+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
19+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
20+
"metadata": {
21+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
22+
}
23+
}
24+
]
25+
}

examples/pytorch_image_classification/src/Algorithm.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
import json
66
from torchvision import models, transforms
77

8-
def load_labels(label_path, client):
9-
local_path = client.file(label_path).getFile().name
10-
with open(local_path) as f:
8+
9+
client = Algorithmia.client()
10+
11+
def load_labels(label_path):
    """Read a JSON class index from ``label_path`` and return its labels in index order.

    The file is expected to map the string keys "0", "1", ... to sequences
    whose second element is the label; one label is returned per key.
    """
    with open(label_path) as handle:
        class_index = json.load(handle)
    names = []
    for position in range(len(class_index)):
        # Keys are stringified indices; element 1 of each entry is the label.
        names.append(class_index[str(position)][1])
    return names
1416

1517

16-
def load_model(model_paths, client):
18+
def load_model(model_path):
    """Build a SqueezeNet 1.1 network and load its weights from ``model_path``.

    :param model_path: Local path of the serialized state dict.
    :return: The network after ``.float().eval()`` has been applied.
    """
    network = models.squeezenet1_1()
    network.load_state_dict(torch.load(model_path))
    return network.float().eval()
2223

@@ -52,26 +53,25 @@ def infer_image(image_url, n, globals):
5253

5354
def load(manifest):
    """Assemble the state dict that is handed to ``apply``.

    :param manifest: Object exposing ``get_model(name)``, which returns a local file path.
    :return: Dict with the image-downloader algorithm URI, the loaded model and the labels.
    """
    return {
        "SMID_ALGO": "algo://util/SmartImageDownloader/0.2.x",
        "model": load_model(manifest.get_model("model_squeezenet")),
        "labels": load_labels(manifest.get_model("labels")),
    }
6161

6262

63-
def apply(input, globals):
63+
def apply(input, state):
6464
if isinstance(input, dict):
6565
if "n" in input:
6666
n = input["n"]
6767
else:
6868
n = 3
6969
if "data" in input:
7070
if isinstance(input["data"], str):
71-
output = infer_image(input["data"], n, globals)
71+
output = infer_image(input["data"], n, state)
7272
elif isinstance(input["data"], list):
7373
for row in input["data"]:
74-
row["predictions"] = infer_image(row["image_url"], n, globals)
74+
row["predictions"] = infer_image(row["image_url"], n, state)
7575
output = input["data"]
7676
else:
7777
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
@@ -82,5 +82,5 @@ def apply(input, globals):
8282
raise Exception("input must be a json object")
8383

8484

85-
algorithm = ADK(apply_func=apply, load_func=load)
85+
algorithm = ADK(apply_func=apply, load_func=load, client=client)
8686
algorithm.init({"data": "https://i.imgur.com/bXdORXl.jpeg"})

0 commit comments

Comments
 (0)