Hi,
I hit an assertion failure when training RMI on a very small dataset (100 items) for testing purposes:
thread '<unnamed>' panicked at 'start index was 100 but end index was 100', [...]/RMI/rmi_lib/src/train/two_layer.rs:27:5
Upon closer investigation, I think I have found an off-by-one error in the train_two_layer function implementation. I did not look at the context beyond that function, so take what follows with a grain of salt:
1. The value of split_idx is calculated here:
RMI/rmi_lib/src/train/two_layer.rs
Line 132 in 5fdff45
| let split_idx = md_container.lower_bound_by(|x| { |
2. split_idx should be in the interval [0, md_container.len()):
RMI/rmi_lib/src/train/two_layer.rs
Line 139 in 5fdff45
| if split_idx > 0 && split_idx < md_container.len() { |
Now let's look at the case where split_idx == md_container.len() - 1, which is valid per [2.]:
3. The else branch is taken, since split_idx < md_container.len():
RMI/rmi_lib/src/train/two_layer.rs
Line 147 in 5fdff45
| let mut leaf_models = if split_idx >= md_container.len() { |
4. split_idx + 1 (== md_container.len()) is passed to build_models_from as start_idx:
|| build_models_from(&md_container, &top_model, layer2_model,
split_idx + 1, md_container.len(),
split_idx_target,
second_half_models)
fn build_models_from<T: TrainingKey>(data: &RMITrainingData<T>,
top_model: &Box<dyn Model>,
model_type: &str,
start_idx: usize, end_idx: usize,
first_model_idx: usize,
num_models: usize) -> Vec<Box<dyn Model>>
5. The assert fails, since md_container.len() > md_container.len() is false:
assert!(end_idx > start_idx,
"start index was {} but end index was {}",
start_idx, end_idx
);
An obvious fix would be to change the condition in [3.] to split_idx >= md_container.len() - 1; however, I am not entirely certain whether that leads to issues in other contexts. I guess a similar issue would happen if split_idx == 0, but only for the first call. I changed the condition in my local version and re-ran the tests, and it seems to work just fine:
Running test cache_fix_osm
Test cache_fix_osm finished.
Running test cache_fix_wiki
Test cache_fix_wiki finished.
Running test max_size_wiki
Test max_size_wiki finished.
Running test radix_model_wiki
Test radix_model_wiki finished.
Running test simple_model_osm
Test simple_model_osm finished.
Running test simple_model_wiki
Test simple_model_wiki finished.
============== TEST RESULTS ===============
python3 report.py
PASS cache_fix_osm
PASS cache_fix_wiki
PASS max_size_wiki
PASS radix_model_wiki
PASS simple_model_osm
PASS simple_model_wiki
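For reference, the failing case and the proposed condition can be sketched in isolation. This is only a minimal sketch, not the RMI code: `build_models_from_stub`, `second_half_current`, and `second_half_proposed` are hypothetical names, the stub keeps nothing but the index assertion, and I am assuming that the `split_idx >= md_container.len()` branch simply skips building the second half.

```rust
/// Hypothetical stand-in for `build_models_from`: it keeps only the index
/// assertion from two_layer.rs and returns the number of items in the range.
fn build_models_from_stub(start_idx: usize, end_idx: usize) -> usize {
    assert!(end_idx > start_idx,
            "start index was {} but end index was {}",
            start_idx, end_idx);
    end_idx - start_idx
}

/// Range passed to the second-half build under the current condition.
/// Assumption: `None` means the second-half build is skipped entirely.
fn second_half_current(split_idx: usize, len: usize) -> Option<(usize, usize)> {
    if split_idx >= len {
        None
    } else {
        Some((split_idx + 1, len))
    }
}

/// Same as above, but with the proposed condition. It is written as
/// `split_idx + 1 >= len`, which matches `split_idx >= len - 1` for
/// len > 0 and avoids unsigned underflow when len == 0.
fn second_half_proposed(split_idx: usize, len: usize) -> Option<(usize, usize)> {
    if split_idx + 1 >= len {
        None
    } else {
        Some((split_idx + 1, len))
    }
}
```

With len = 100 and split_idx = 99, `second_half_current` produces the empty range (100, 100), which would trip the assert in the stub, whereas `second_half_proposed` returns `None`. For a split in the middle, e.g. split_idx = 50, both return (51, 100) and the stub accepts it.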
I can open a pull request with that fix if you would like.