Skip to content

Commit 3b20ccf

Browse files
committed
added tests and fixed some parts of the manifest system
1 parent 32576e4 commit 3b20ccf

File tree

6 files changed

+132
-4
lines changed

6 files changed

+132
-4
lines changed

adk/manifest.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def initialize(self):
2929
with self.client.file(required_file['data_api_path']).getFile() as f:
3030
local_data_path = f.name
3131
real_hash = md5(local_data_path)
32-
if not real_hash != expected_hash and required_file['fail_on_tamper']:
32+
if real_hash != expected_hash and required_file['fail_on_tamper']:
3333
raise Exception("Model File Mismatch for " + name +
3434
"\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
3535
else:
@@ -39,9 +39,12 @@ def initialize(self):
3939
def get_model(self, model_name):
4040
if model_name in self.models:
4141
return self.models[model_name]['model_path']
42-
elif model_name in self.manifest_data['optional_files']:
42+
elif len([optional for optional in self.manifest_data['optional_models'] if
43+
optional['name'] == model_name]) > 0:
4344
self.find_optional_model(model_name)
4445
return self.models[model_name]['model_path']
46+
else:
47+
raise Exception("model name " + model_name + " not found in manifest")
4548

4649
def find_optional_model(self, model_name):
4750

@@ -55,7 +58,7 @@ def find_optional_model(self, model_name):
5558
with self.client.file(model_info['data_api_path']).getFile() as f:
5659
local_data_path = f.name
5760
real_hash = md5(local_data_path)
58-
if not real_hash != expected_hash and model_info['fail_on_tamper']:
61+
if real_hash != expected_hash and model_info['fail_on_tamper']:
5962
raise Exception("Model File Mismatch for " + model_name +
6063
"\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
6164
else:
@@ -77,4 +80,4 @@ def md5(fname):
7780
with open(fname, "rb") as f:
7881
for chunk in iter(lambda: f.read(4096), b""):
7982
hash_md5.update(chunk)
80-
return hash_md5.hexdigest()
83+
return str(hash_md5.hexdigest())

examples/pytorch_image_classification/model_manifest.json.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
{ "name": "labels",
88
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
99
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
10+
"fail_on_tamper": true,
1011
"metadata": {
1112
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
1213
}
@@ -17,6 +18,7 @@
1718
"name": "optional_data",
1819
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
1920
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
21+
"fail_on_tamper": true,
2022
"metadata": {
2123
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
2224
}

tests/adk_algorithms.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
import Algorithmia
22
import base64
3+
import os
4+
35

46
# -- Apply functions --- #
57
def apply_basic(input):
68
return "hello " + input
79

10+
811
def apply_binary(input):
912
if isinstance(input, bytes):
1013
input = input.decode('utf8')
1114
return bytes("hello " + input, encoding='utf8')
1215

16+
1317
def apply_input_or_context(input, globals=None):
1418
if isinstance(globals, dict):
1519
return globals
1620
else:
1721
return "hello " + input
1822

1923

24+
def apply_successful_manifest_parsing(input, result):
25+
if result:
26+
return "all model files were successfully loaded"
27+
else:
28+
return "model files were not loaded correctly"
29+
30+
2031
# -- Loading functions --- #
2132
def loading_text():
2233
context = dict()
@@ -34,3 +45,14 @@ def loading_file_from_algorithmia():
3445
context['data_url'] = 'data://demo/collection/somefile.json'
3546
context['data'] = client.file(context['data_url']).getJson()
3647
return context
48+
49+
50+
def loading_with_manifest(manifest):
51+
squeezenet_path = manifest.get_model("squeezenet")
52+
labels_path = manifest.get_model("labels")
53+
# optional model
54+
mobilenet_path = manifest.get_model("mobilenet")
55+
if os.path.exists(squeezenet_path) and os.path.exists(labels_path) and os.path.exists(mobilenet_path):
56+
return True
57+
else:
58+
return False

tests/bad_model_manifest.json.lock

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"algorithm_name": "test_algorithm",
3+
"timestamp": "1632770803",
4+
"required_models" : [
5+
{ "name": "squeezenet",
6+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
7+
"md5_checksum": "f20b50b44fdef367a225d41f747a0963",
8+
"fail_on_tamper": true,
9+
"metadata": {
10+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
11+
}
12+
},
13+
{
14+
"name": "labels",
15+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
16+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
17+
"fail_on_tamper": true,
18+
"metadata": {
19+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
20+
}
21+
}
22+
],
23+
"optional_models": [
24+
{
25+
"name": "mobilenet",
26+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth",
27+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
28+
"fail_on_tamper": false,
29+
"metadata": {
30+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
31+
}
32+
}
33+
]
34+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"algorithm_name": "test_algorithm",
3+
"timestamp": "1632770803",
4+
"required_models" : [
5+
{ "name": "squeezenet",
6+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
7+
"md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266",
8+
"fail_on_tamper": true,
9+
"metadata": {
10+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
11+
}
12+
},
13+
{
14+
"name": "labels",
15+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
16+
"md5_checksum": "c2c37ea517e94d9795004a39431a14cb",
17+
"fail_on_tamper": true,
18+
"metadata": {
19+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
20+
}
21+
}
22+
],
23+
"optional_models": [
24+
{
25+
"name": "mobilenet",
26+
"data_api_path": "data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth",
27+
"md5_checksum": "f20b50b44fdef367a225d41f747a0963",
28+
"fail_on_tamper": false,
29+
"metadata": {
30+
"origination_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
31+
}
32+
}
33+
]
34+
}

tests/test_adk_local.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ def execute_example(self, input, apply, load=lambda: None):
2020
algo.init(input, pprint=lambda x: output.append(x))
2121
return output[0]
2222

23+
def execute_manifest_example(self, input, apply, load, manifest_path="good_model_manifest.json.lock"):
24+
client = Algorithmia.client()
25+
algo = ADK(apply, load, manifest_path=manifest_path, client=client)
26+
output = []
27+
algo.init(input, pprint=lambda x: output.append(x))
28+
return output[0]
29+
2330
def execute_without_load(self, input, apply):
2431
algo = ADK(apply)
2532
output = []
@@ -110,6 +117,32 @@ def test_binary_data(self):
110117
actual_output = json.loads(self.execute_without_load(input, apply_binary))
111118
self.assertEqual(expected_output, actual_output)
112119

120+
def test_manifest_file_success(self):
121+
input = "Algorithmia"
122+
expected_output = {'metadata':
123+
{
124+
'content_type': 'text'
125+
},
126+
'result': "all model files were successfully loaded"
127+
}
128+
actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing,
129+
loading_with_manifest,
130+
manifest_path="tests/good_model_manifest.json.lock"))
131+
self.assertEqual(expected_output, actual_output)
132+
133+
def test_manifest_file_tampered(self):
134+
input = "Algorithmia"
135+
expected_output = {"error": {"error_type": "LoadingError",
136+
"message": "Model File Mismatch for squeezenet\n"
137+
"expected hash: f20b50b44fdef367a225d41f747a0963\n"
138+
"real hash: 46a44d32d2c5c07f7f66324bef4c7266",
139+
"stacktrace": "NoneType: None\n"}}
140+
141+
actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing,
142+
loading_with_manifest,
143+
manifest_path="tests/bad_model_manifest.json.lock"))
144+
self.assertEqual(expected_output, actual_output)
145+
113146

114147
def run_test():
115148
unittest.main()

0 commit comments

Comments
 (0)