diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index afcf0683d7d04..04ec3ee68c16a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -31,6 +31,7 @@ AINODE_CONF_FILE_NAME, AINODE_CONF_GIT_FILE_NAME, AINODE_CONF_POM_FILE_NAME, + AINODE_FINETUNE_MODELS_DIR, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, AINODE_INFERENCE_EXTRA_MEMORY_RATIO, AINODE_INFERENCE_MAX_PREDICT_LENGTH, @@ -44,6 +45,7 @@ AINODE_SYSTEM_FILE_NAME, AINODE_TARGET_CONFIG_NODE_LIST, AINODE_THRIFT_COMPRESSION_ENABLED, + AINODE_USER_DEFINED_MODELS_DIR, AINODE_VERSION_INFO, ) from iotdb.ainode.core.exception import BadNodeUrlError @@ -96,6 +98,8 @@ def __init__(self): # Directory to save models self._ain_models_dir = AINODE_MODELS_DIR self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR + self._ain_finetune_models_dir = AINODE_FINETUNE_MODELS_DIR + self._ain_user_defined_models_dir = AINODE_USER_DEFINED_MODELS_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ -210,6 +214,18 @@ def get_ain_builtin_models_dir(self) -> str: def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None: self._ain_builtin_models_dir = ain_builtin_models_dir + def get_ain_finetune_models_dir(self) -> str: + return self._ain_finetune_models_dir + + def set_ain_finetune_models_dir(self, ain_finetune_models_dir: str) -> None: + self._ain_finetune_models_dir = ain_finetune_models_dir + + def get_ain_user_defined_models_dir(self) -> str: + return self._ain_user_defined_models_dir + + def set_ain_user_defined_models_dir(self, ain_user_defined_models_dir: str) -> None: + self._ain_user_defined_models_dir = ain_user_defined_models_dir + def get_ain_system_dir(self) -> str: return self._ain_system_dir diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index b9923d3e3ee7e..791aa2ac5429d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -20,7 +20,6 @@ from enum import Enum from typing import List -from iotdb.ainode.core.model.model_enums import BuiltInModelType from iotdb.thrift.common.ttypes import TEndPoint IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "") @@ -52,8 +51,8 @@ AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { - BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB - BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB + "sundial": 1036 * 1024**2, # 1036 MiB + "timerxl": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( @@ -63,8 +62,15 @@ # AINode folder structure AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") AINODE_BUILTIN_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/weights" + IOTDB_AINODE_HOME, "data/ainode/models/builtin" ) # For built-in models, we only need to store their weights and config. 
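+# Illustrative sketch of the resulting on-disk layout (directory names come from
+# the constants in this file; the per-directory contents are assumptions):
+#   data/ainode/models/
+#     builtin/       # weights + config of built-in models (e.g. sundial, timerxl)
+#     finetune/      # checkpoints produced by fine-tuning
+#     user_defined/  # models registered by users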
+AINODE_FINETUNE_MODELS_DIR = os.path.join( + IOTDB_AINODE_HOME, "data/ainode/models/finetune" +) +AINODE_USER_DEFINED_MODELS_DIR = os.path.join( + IOTDB_AINODE_HOME, "data/ainode/models/user_defined" +) +AINODE_CACHE_DIR = os.path.expanduser("~/.cache/ainode") AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system") AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs") @@ -141,132 +147,8 @@ def name(self): return self.value -class ForecastModelType(Enum): - DLINEAR = "dlinear" - DLINEAR_INDIVIDUAL = "dlinear_individual" - NBEATS = "nbeats" - - @classmethod - def values(cls) -> List[str]: - values = [] - for item in list(cls): - values.append(item.value) - return values - - class ModelInputName(Enum): DATA_X = "data_x" TIME_STAMP_X = "time_stamp_x" TIME_STAMP_Y = "time_stamp_y" DEC_INP = "dec_inp" - - -class AttributeName(Enum): - # forecast Attribute - PREDICT_LENGTH = "predict_length" - - # NaiveForecaster - STRATEGY = "strategy" - SP = "sp" - - # STLForecaster - # SP = 'sp' - SEASONAL = "seasonal" - SEASONAL_DEG = "seasonal_deg" - TREND_DEG = "trend_deg" - LOW_PASS_DEG = "low_pass_deg" - SEASONAL_JUMP = "seasonal_jump" - TREND_JUMP = "trend_jump" - LOSS_PASS_JUMP = "low_pass_jump" - - # ExponentialSmoothing - DAMPED_TREND = "damped_trend" - INITIALIZATION_METHOD = "initialization_method" - OPTIMIZED = "optimized" - REMOVE_BIAS = "remove_bias" - USE_BRUTE = "use_brute" - - # Arima - ORDER = "order" - SEASONAL_ORDER = "seasonal_order" - METHOD = "method" - MAXITER = "maxiter" - SUPPRESS_WARNINGS = "suppress_warnings" - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" - SCORING = "scoring" - WITH_INTERCEPT = "with_intercept" - TIME_VARYING_REGRESSION = "time_varying_regression" - ENFORCE_STATIONARITY = "enforce_stationarity" - ENFORCE_INVERTIBILITY = "enforce_invertibility" - SIMPLE_DIFFERENCING = "simple_differencing" - MEASUREMENT_ERROR = "measurement_error" - MLE_REGRESSION = "mle_regression" - HAMILTON_REPRESENTATION = "hamilton_representation" - CONCENTRATE_SCALE = "concentrate_scale" - - # GAUSSIAN_HMM - N_COMPONENTS = "n_components" - COVARIANCE_TYPE = "covariance_type" - MIN_COVAR = "min_covar" - STARTPROB_PRIOR = "startprob_prior" - TRANSMAT_PRIOR = "transmat_prior" - MEANS_PRIOR = "means_prior" - MEANS_WEIGHT = "means_weight" - COVARS_PRIOR = "covars_prior" - COVARS_WEIGHT = "covars_weight" - ALGORITHM = "algorithm" - N_ITER = "n_iter" - TOL = "tol" - PARAMS = "params" - INIT_PARAMS = "init_params" - IMPLEMENTATION = "implementation" - - # GMMHMM - # N_COMPONENTS = "n_components" - N_MIX = "n_mix" - # MIN_COVAR = "min_covar" - # STARTPROB_PRIOR = "startprob_prior" - # TRANSMAT_PRIOR = "transmat_prior" - WEIGHTS_PRIOR = "weights_prior" - - # MEANS_PRIOR = "means_prior" - # MEANS_WEIGHT = "means_weight" - # ALGORITHM = "algorithm" - # COVARIANCE_TYPE = "covariance_type" - # N_ITER = "n_iter" - # TOL = "tol" - # INIT_PARAMS = "init_params" - # PARAMS = "params" - # IMPLEMENTATION = "implementation" - - # STRAY - ALPHA = "alpha" - K = "k" - KNN_ALGORITHM = "knn_algorithm" - P = "p" - SIZE_THRESHOLD = "size_threshold" - OUTLIER_TAIL = "outlier_tail" - - # timerxl - INPUT_TOKEN_LEN = "input_token_len" - HIDDEN_SIZE = "hidden_size" - INTERMEDIATE_SIZE = "intermediate_size" - OUTPUT_TOKEN_LENS = "output_token_lens" - NUM_HIDDEN_LAYERS = "num_hidden_layers" - NUM_ATTENTION_HEADS = "num_attention_heads" - HIDDEN_ACT = "hidden_act" - USE_CACHE = "use_cache" - ROPE_THETA = "rope_theta" - ATTENTION_DROPOUT = "attention_dropout" - INITIALIZER_RANGE = "initializer_range" - 
MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - CKPT_PATH = "ckpt_path" - - # sundial - DROPOUT_RATE = "dropout_rate" - FLOW_LOSS_DEPTH = "flow_loss_depth" - NUM_SAMPLING_STEPS = "num_sampling_steps" - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" - - def name(self) -> str: - return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py index 82c72cc37abf5..a70445c5efd4c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. # + import threading from typing import Any import torch -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.atmoic_int import AtomicInt @@ -41,7 +39,6 @@ def __init__( req_id: str, model_id: str, inputs: torch.Tensor, - inference_pipeline: AbstractInferencePipeline, max_new_tokens: int = 96, **infer_kwargs, ): @@ -52,7 +49,6 @@ def __init__( self.model_id = model_id self.inputs = inputs self.infer_kwargs = infer_kwargs - self.inference_pipeline = inference_pipeline self.max_new_tokens = ( max_new_tokens # Number of time series data points to generate ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 6b054c91fe31c..fcaa4c7a7543b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -25,19 +25,17 @@ import numpy as np import torch import torch.multiprocessing as mp -from transformers import PretrainedConfig from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest +from iotdb.ainode.core.inference.pipeline import get_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device @@ -62,7 +60,6 @@ def __init__( pool_id: int, model_info: ModelInfo, device: str, - config: PretrainedConfig, request_queue: mp.Queue, result_queue: mp.Queue, ready_event, @@ -71,7 +68,6 @@ def __init__( super().__init__() self.pool_id = pool_id self.model_info = model_info - self.config = config self.pool_kwargs = pool_kwargs self.ready_event = ready_event self.device = convert_device_id_to_torch_device(device) @@ -86,8 +82,8 @@ def __init__( self._batcher = BasicBatcher() self._stop_event = mp.Event() - self._model = None - self._model_manager = None + # self._inference_pipeline = get_pipeline(self.model_info.model_id, self.device) + self._logger = None # Fix inference seed @@ -98,9 +94,6 @@ def __init__( def _activate_requests(self): requests = self._request_scheduler.schedule_activate() for request 
in requests: - request.inputs = request.inference_pipeline.preprocess_inputs( - request.inputs - ) request.mark_running() self._running_queue.put(request) self._logger.debug( @@ -123,66 +116,36 @@ def _step(self): for requests in grouped_requests: batch_inputs = self._batcher.batch_request(requests).to(self.device) - if self.model_info.model_type == BuiltInModelType.SUNDIAL.value: - batch_output = self._model.generate( - batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - num_samples=10, - revin=True, - ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output.mean(dim=1)) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) - - elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value: - batch_output = self._model.generate( - batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - revin=True, - ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) + batch_output = self._inference_pipeline.infer( + batch_inputs, + predict_length=requests[0].max_new_tokens, + # num_samples=10, + revin=True, + ) + offset = 0 + for request in requests: + request.output_tensor = request.output_tensor.to(self.device) + cur_batch_size = request.batch_size + cur_output = batch_output[offset : offset + cur_batch_size] + offset += cur_batch_size + # request.write_step_output(cur_output.mean(dim=1)) + request.write_step_output(cur_output) + + # self._inference_pipeline.post_decode() + if request.is_finished(): + # self._inference_pipeline.post_inference() + # ensure the output tensor is on CPU before sending to result queue + request.output_tensor = request.output_tensor.cpu() + self._finished_queue.put(request) + self._logger.debug( + f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" + ) + else: + self._waiting_queue.put(request) + self._logger.debug( + f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" + ) + return def 
_requests_execute_loop(self): while not self._stop_event.is_set(): @@ -193,11 +156,8 @@ def run(self): self._logger = Logger( INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device) ) - self._model_manager = ModelManager() self._request_scheduler.device = self.device - self._model = self._model_manager.load_model(self.model_info.model_id, {}).to( - self.device - ) + self._inference_pipeline = get_pipeline(self.model_info.model_id, self.device) self.ready_event.set() activate_daemon = threading.Thread( diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py new file mode 100644 index 0000000000000..617c4e6738061 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from iotdb.ainode.core.inference.pipeline.sundial_pipeline import SundialPipeline +from iotdb.ainode.core.inference.pipeline.timerxl_pipeline import TimerxlPipeline + + +def get_pipeline(model_id, device): + if model_id == "timerxl": + return TimerxlPipeline(model_id, device=device) + elif model_id == "sundial": + return SundialPipeline(model_id, device=device) + else: + raise ValueError(f"Unsupported model_id: {model_id} with pipeline") diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py new file mode 100644 index 0000000000000..c413a92e82d62 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+
+from abc import ABC
+
+import torch
+
+from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.manager.model_manager import get_model_manager
+
+
+class BasicPipeline(ABC):
+    """
+    Base class shared by the inference pipelines: loads the model for the given
+    model_id and defines the pre-/post-processing hooks.
+    """
+
+    def __init__(self, model_id, **infer_kwargs):
+        self.model_id = model_id
+        self.device = infer_kwargs.get("device", "cpu")
+        self.model = get_model_manager().load_model(
+            model_id, device_map=str(self.device)
+        )
+
+    def _preprocess(self, inputs):
+        """
+        Preprocess the input before inference, including shape validation and value transformation.
+        """
+        # TODO: Integrate with the data processing pipeline operators
+        pass
+
+    def infer(self, inputs):
+        pass
+
+    def _post_decode(self):
+        """
+        Post-process the outputs after each decode step.
+        """
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        """
+        Post-process the outputs after the entire inference task.
+        """
+        pass
+
+
+class ForecastPipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        # Unpack the keyword arguments so that options such as `device`
+        # actually reach BasicPipeline instead of being nested under a
+        # single `infer_kwargs` key.
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        if len(inputs.shape) != 2:
+            raise InferenceModelInternalError(
+                f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
+            )
+        return inputs
+
+    def forecast(self, inputs, **infer_kwargs):
+        pass
+
+    def _post_decode(self):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        pass
+
+
+class ClassificationPipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        pass
+
+    def classify(self, inputs, **kwargs):
+        pass
+
+    def _post_decode(self):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        pass
+
+
+class ChatPipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        pass
+
+    def chat(self, inputs, **kwargs):
+        pass
+
+    def _post_decode(self):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        pass
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py
new file mode 100644
index 0000000000000..004222db7cad5
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# + +import numpy as np +import pandas as pd +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import BasicPipeline + + +class SktimePipeline(BasicPipeline): + def __init__(self, model_id, **infer_kwargs): + super().__init__(model_id, infer_kwargs=infer_kwargs) + + def _preprocess(self, inputs): + return super()._preprocess(inputs) + + def infer(self, inputs, **infer_kwargs): + input_ids = self._preprocess(inputs) + + # Convert to pandas Series for sktime (sktime expects Series or DataFrame) + # Handle batch dimension: if batch_size > 1, process each sample separately + if len(input_ids.shape) == 2 and input_ids.shape[0] > 1: + # Batch processing: convert each row to Series + outputs = [] + for i in range(input_ids.shape[0]): + series = pd.Series(input_ids[i].cpu().numpy() if isinstance(input_ids, torch.Tensor) else input_ids[i]) + output = self.model.generate(series) + outputs.append(output) + output = np.array(outputs) + else: + # Single sample: convert to Series + if isinstance(input_ids, torch.Tensor): + series = pd.Series(input_ids.squeeze().cpu().numpy()) + else: + series = pd.Series(input_ids.squeeze()) + output = self.model.generate(series) + # Add batch dimension if needed + if len(output.shape) == 1: + output = output[np.newaxis, :] + + return self._postprocess(output) + + def _postprocess(self, output): + if isinstance(output, np.ndarray): + return torch.from_numpy(output).float() + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py new file mode 100644 index 0000000000000..8d0909954bf24 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+
+
+class SundialPipeline(ForecastPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        return super()._preprocess(inputs)
+
+    def infer(self, inputs, **infer_kwargs):
+        predict_length = infer_kwargs.get("predict_length", 96)
+        num_samples = infer_kwargs.get("num_samples", 10)
+        revin = infer_kwargs.get("revin", True)
+
+        input_ids = self._preprocess(inputs)
+        output = self.model.generate(
+            input_ids,
+            max_new_tokens=predict_length,
+            num_samples=num_samples,
+            revin=revin,
+        )
+        return self._postprocess(output)
+
+    def _postprocess(self, output: torch.Tensor):
+        # Sundial produces num_samples forecasts per input; average them into one series.
+        return output.mean(dim=1)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py
new file mode 100644
index 0000000000000..cf6d35c805c5d
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+
+
+class TimerxlPipeline(ForecastPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        return super()._preprocess(inputs)
+
+    def infer(self, inputs, **infer_kwargs):
+        predict_length = infer_kwargs.get("predict_length", 96)
+        revin = infer_kwargs.get("revin", True)
+
+        input_ids = self._preprocess(inputs)
+        output = self.model.generate(
+            input_ids, max_new_tokens=predict_length, revin=revin
+        )
+        return self._postprocess(output)
+
+    def _postprocess(self, output: torch.Tensor):
+        return output
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 54580402ec293..ca9680dcf21e2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -40,16 +40,14 @@
     ScaleActionType,
 )
 from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
+from iotdb.ainode.core.manager.model_manager import get_model_manager
 from iotdb.ainode.core.util.atmoic_int import AtomicInt
 from iotdb.ainode.core.util.batch_executor import BatchExecutor
 from iotdb.ainode.core.util.decorator import synchronized
 from iotdb.ainode.core.util.thread_name import ThreadName
 
 logger = Logger()
+MODEL_MANAGER = get_model_manager()
 
 
 class PoolController:
@@ -58,7 +56,6 @@ class PoolController:
     """
 
     def __init__(self, result_queue: mp.Queue):
-        self._model_manager = ModelManager()
         # structure: {model_id: {device_id: PoolGroup}}
         self._request_pool_map: Dict[str, Dict[str, PoolGroup]] = {}
         self._new_pool_id = AtomicInt()
@@ -78,7 +75,7 @@ def __init__(self, result_queue: mp.Queue):
 
     # =============== Pool Management ===============
     @synchronized(threading.Lock())
-    def first_req_init(self, model_id: str):
+    def first_req_init(self, model_id: str, device):
         """
         Initialize the pools when the first request for the given model_id arrives.
""" @@ -110,17 +107,12 @@ def _first_pool_init(self, model_id: str, device_str: str): device = torch.device(device_str) device_id = device.index - if model_id == "sundial": - config = SundialConfig() - elif model_id == "timer_xl": - config = TimerConfig() first_queue = mp.Queue() ready_event = mp.Event() first_pool = InferenceRequestPool( pool_id=0, model_id=model_id, device=device_str, - config=config, request_queue=first_queue, result_queue=self._result_queue, ready_event=ready_event, @@ -195,7 +187,7 @@ def _load_model_task(self, model_id: str, device_id_list: list[str]): def _load_model_on_device_task(device_id: str): if not self.has_request_pools(model_id, device_id): actions = self._pool_scheduler.schedule_load_model_to_device( - self._model_manager.get_model_info(model_id), device_id + MODEL_MANAGER.get_model_info(model_id), device_id ) for action in actions: if action.action == ScaleActionType.SCALE_UP: @@ -222,7 +214,7 @@ def _unload_model_task(self, model_id: str, device_id_list: list[str]): def _unload_model_on_device_task(device_id: str): if self.has_request_pools(model_id, device_id): actions = self._pool_scheduler.schedule_unload_model_from_device( - self._model_manager.get_model_info(model_id), device_id + MODEL_MANAGER.get_model_info(model_id), device_id ) for action in actions: if action.action == ScaleActionType.SCALE_DOWN: @@ -255,30 +247,20 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int): """ def _expand_pool_on_device(*_): - result_queue = mp.Queue() + request_queue = mp.Queue() pool_id = self._new_pool_id.get_and_increment() - model_info = self._model_manager.get_model_info(model_id) - model_type = model_info.model_type - if model_type == BuiltInModelType.SUNDIAL.value: - config = SundialConfig() - elif model_type == BuiltInModelType.TIMER_XL.value: - config = TimerConfig() - else: - raise InferenceModelInternalError( - f"Unsupported model type {model_type} for loading model {model_id}" - ) + model_info = MODEL_MANAGER.get_model_info(model_id) pool = InferenceRequestPool( pool_id=pool_id, model_info=model_info, device=device_id, - config=config, - request_queue=result_queue, + request_queue=request_queue, result_queue=self._result_queue, ready_event=mp.Event(), ) pool.start() - self._register_pool(model_id, device_id, pool_id, pool, result_queue) - if not pool.ready_event.wait(timeout=300): + self._register_pool(model_id, device_id, pool_id, pool, request_queue) + if not pool.ready_event.wait(timeout=30): logger.error( f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time" ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 6a2bd2b619aa7..63cd1829dde08 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -28,7 +28,7 @@ ScaleActionType, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager +from iotdb.ainode.core.manager.model_manager import get_model_manager from iotdb.ainode.core.manager.utils import ( INFERENCE_EXTRA_MEMORY_RATIO, INFERENCE_MEMORY_USAGE_RATIO, @@ -36,11 +36,13 @@ estimate_pool_size, evaluate_system_resources, ) -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo +from iotdb.ainode.core.model.model_info import ModelInfo from 
iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device logger = Logger() +MODEL_MANAGER = get_model_manager() + def _estimate_shared_pool_size_by_total_mem( device: torch.device, @@ -63,7 +65,7 @@ def _estimate_shared_pool_size_by_total_mem( mem_usages: Dict[str, float] = {} for model_info in all_models: mem_usages[model_info.model_id] = ( - MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO + MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO ) # Evaluate system resources and get TOTAL memory diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py deleted file mode 100644 index 2300169a6ee93..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from abc import ABC, abstractmethod - -import torch - - -class AbstractInferencePipeline(ABC): - """ - Abstract assistance strategy class for model inference. - This class shall define the interface process for specific model. - """ - - def __init__(self, model_config, **infer_kwargs): - self.model_config = model_config - self.infer_kwargs = infer_kwargs - - @abstractmethod - def preprocess_inputs(self, inputs: torch.Tensor): - """ - Preprocess the inputs before inference, including shape validation and value transformation. - - Args: - inputs (torch.Tensor): The input tensor to be preprocessed. - - Returns: - torch.Tensor: The preprocessed input tensor. - """ - # TODO: Integrate with the data processing pipeline operators - pass - - @abstractmethod - def post_decode(self): - """ - Post-process the outputs after each decode step. - """ - pass - - @abstractmethod - def post_inference(self): - """ - Post-process the outputs after the entire inference task. - """ - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py deleted file mode 100644 index 17c88e32fb5a0..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py +++ /dev/null @@ -1,51 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import torch - -from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig - - -class TimerSundialInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-Sundial model inference. - """ - - def __init__(self, model_config: SundialConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) - if len(inputs.shape) != 2: - raise InferenceModelInternalError( - f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" - ) - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - return inputs - - def post_decode(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass - - def post_inference(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py deleted file mode 100644 index dc1dd304f68e8..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py +++ /dev/null @@ -1,51 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import torch - -from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig - - -class TimerXLInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-XL model inference. 
- """ - - def __init__(self, model_config: TimerConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) - if len(inputs.shape) != 2: - raise InferenceModelInternalError( - f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" - ) - # Considering that we are currently using the generate function interface, it seems that no pre-processing is required - return inputs - - def post_decode(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass - - def post_inference(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index a67d576b0ec8c..6f022f0ddd1e4 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -37,21 +37,11 @@ InferenceRequest, InferenceRequestProxy, ) +from iotdb.ainode.core.inference.pipeline import get_pipeline from iotdb.ainode.core.inference.pool_controller import PoolController -from iotdb.ainode.core.inference.strategy.timer_sundial_inference_pipeline import ( - TimerSundialInferencePipeline, -) -from iotdb.ainode.core.inference.strategy.timerxl_inference_pipeline import ( - TimerXLInferencePipeline, -) from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.modeling_timer import TimerForPrediction +from iotdb.ainode.core.manager.model_manager import get_model_manager from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.ainode.core.util.serde import convert_to_binary @@ -71,90 +61,13 @@ logger = Logger() -class InferenceStrategy(ABC): - def __init__(self, model): - self.model = model - - @abstractmethod - def infer(self, full_data, **kwargs): - pass - - -# [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length], -# we only get valueList currently. 
-class TimerXLStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) - df = pd.DataFrame(output[0]) - return convert_to_binary(df) - - -class SundialStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate( - seqs, max_new_tokens=predict_length, num_samples=10, revin=True - ) - df = pd.DataFrame(output[0].mean(dim=0)) - return convert_to_binary(df) - - -class BuiltInStrategy(InferenceStrategy): - def infer(self, full_data, **_): - data = pd.DataFrame(full_data[1]).T - output = self.model.inference(data) - df = pd.DataFrame(output) - return convert_to_binary(df) - - -class RegisteredStrategy(InferenceStrategy): - def infer(self, full_data, window_interval=None, window_step=None, **_): - _, dataset, _, length = full_data - if window_interval is None or window_step is None: - window_interval = length - window_step = float("inf") - - if window_interval <= 0 or window_step <= 0 or window_interval > length: - raise InvalidWindowArgumentError(window_interval, window_step, length) - - data = torch.tensor(dataset, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1) - - times = int((length - window_interval) // window_step + 1) - results = [] - try: - for i in range(times): - start = 0 if window_step == float("inf") else i * window_step - end = start + window_interval - window = data[:, start:end, :] - out = self.model(window) - df = pd.DataFrame(out.squeeze(0).detach().numpy()) - results.append(df) - except Exception as e: - msg = runtime_error_extractor(str(e)) or str(e) - raise InferenceModelInternalError(msg) - - # concatenate or return first window for forecast - return [convert_to_binary(df) for df in results] - - class InferenceManager: WAITING_INTERVAL_IN_MS = ( AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms() ) # How often to check for requests in the result queue def __init__(self): - self._model_manager = ModelManager() + self._model_manager = get_model_manager() self._model_mem_usage_map: Dict[str, int] = ( {} ) # store model memory usage for each model @@ -251,15 +164,6 @@ def _process_request(self, req): with self._result_wrapper_lock: del self._result_wrapper_map[req_id] - def _get_strategy(self, model_id, model): - if isinstance(model, TimerForPrediction): - return TimerXLStrategy(model) - if isinstance(model, SundialForPrediction): - return SundialStrategy(model) - if self._model_manager.model_storage.is_built_in_or_fine_tuned(model_id): - return BuiltInStrategy(model) - return RegisteredStrategy(model) - def _run( self, req, @@ -272,9 +176,17 @@ def _run( model_id = req.modelId try: raw = data_getter(req) + # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently. 
full_data = deserializer(raw) - inference_attrs = extract_attrs(req) + # TODO: TSBlock -> Tensor codes should be unified + data = full_data[1][0] # get valueList in ndarray + if data.dtype.byteorder not in ("=", "|"): + np_data = data.byteswap() + data = np_data.view(np_data.dtype.newbyteorder()) + # the inputs should be on CPU before passing to the inference request + inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") + inference_attrs = extract_attrs(req) predict_length = int(inference_attrs.pop("predict_length", 96)) if ( predict_length @@ -290,41 +202,20 @@ def _run( ) if self._pool_controller.has_request_pools(model_id): - # use request pool to accelerate inference when the model instance is already loaded. - # TODO: TSBlock -> Tensor codes should be unified - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - # the inputs should be on CPU before passing to the inference request - inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - model_type = self._model_manager.get_model_info(model_id).model_type - if model_type == BuiltInModelType.SUNDIAL.value: - inference_pipeline = TimerSundialInferencePipeline(SundialConfig()) - elif model_type == BuiltInModelType.TIMER_XL.value: - inference_pipeline = TimerXLInferencePipeline(TimerConfig()) - else: - raise InferenceModelInternalError( - f"Unsupported model_id: {model_id}" - ) infer_req = InferenceRequest( req_id=generate_req_id(), model_id=model_id, inputs=inputs, - inference_pipeline=inference_pipeline, max_new_tokens=predict_length, ) outputs = self._process_request(infer_req) outputs = convert_to_binary(pd.DataFrame(outputs[0])) else: - # load model - accel = str(inference_attrs.get("acceleration", "")).lower() == "true" - model = self._model_manager.load_model(model_id, inference_attrs, accel) - # inference by strategy - strategy = self._get_strategy(model_id, model) - outputs = strategy.infer( - full_data, predict_length=predict_length, **inference_attrs + inference_pipeline = get_pipeline(model_id, device="cpu") + outputs = inference_pipeline.infer( + inputs, predict_length=predict_length, **inference_attrs ) + outputs = convert_to_binary(pd.DataFrame(outputs[0])) # construct response status = get_status(TSStatusCode.SUCCESS_STATUS) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index d84bca77c8430..a07552922ff36 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -15,19 +15,17 @@ # specific language governing permissions and limitations # under the License. 
# -from typing import Callable, Dict -from torch import nn -from yaml import YAMLError +import os +from typing import Any, List, Optional +from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType, ModelStates -from iotdb.ainode.core.model.model_info import ModelInfo -from iotdb.ainode.core.model.model_storage import ModelStorage +from iotdb.ainode.core.model.model_loader import ModelLoader +from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status -from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.ainode.ttypes import ( TDeleteModelReq, TRegisterModelReq, @@ -40,130 +38,84 @@ logger = Logger() -@singleton class ModelManager: def __init__(self): - self.model_storage = ModelStorage() - - def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: - logger.info(f"register model {req.modelId} from {req.uri}") + self.models_dir = os.path.join( + os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() + ) + self.storage = ModelStorage(models_dir=self.models_dir) + self.loader = ModelLoader(storage=self.storage) + + # Automatically discover all models + self._models = self.storage.discover_all() + + def register_model( + self, + req: TRegisterModelReq, + ) -> TRegisterModelResp: try: - configs, attributes = self.model_storage.register_model( - req.modelId, req.uri - ) - return TRegisterModelResp( - get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes - ) - except InvalidUriError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, e.message) - ) - except BadConfigValueError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) - ) - except YAMLError as e: - logger.warning(e) - if hasattr(e, "problem_mark"): - mark = e.problem_mark + success = self.storage.register_model(model_id=req.modelId, uri=req.uri) + if success: + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + else: return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file, " - f"at line {mark.line + 1} column {mark.column + 1}.", - ) + get_status(TSStatusCode.AINODE_INTERNAL_ERROR) ) + except ValueError as e: return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file", - ) + get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) ) except Exception as e: - logger.warning(e) return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + return self.storage.show_models(req) + def delete_model(self, req: TDeleteModelReq) -> TSStatus: - logger.info(f"delete model {req.modelId}") try: - self.model_storage.delete_model(req.modelId) + self.storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) - except Exception as e: + except BuiltInModelDeletionError as e: logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False - ) -> 
Callable: - """ - Load the model with the given model_id. - """ - logger.info(f"Load model {model_id}") - try: - model = self.model_storage.load_model( - model_id, inference_attrs, acceleration - ) - logger.info(f"Model {model_id} loaded") - return model except Exception as e: - logger.error(f"Failed to load model {model_id}: {e}") - raise - - def save_model(self, model_id: str, model: nn.Module) -> TSStatus: - """ - Save the model using save_pretrained - """ - logger.info(f"Saving model {model_id}") - try: - self.model_storage.save_model(model_id, model) - logger.info(f"Saving model {model_id} successfully") - return get_status( - TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved successfully" - ) - except Exception as e: - logger.error(f"Save model failed: {e}") + logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. - - Args: - model_id (str): The ID of the model. + def load_model(self, model_id: str, **kwargs) -> Any: + return self.loader.load_model(model_id=model_id, **kwargs) - Returns: - str: The path to the checkpoint file for the model. - """ - return self.model_storage.get_ckpt_path(model_id) + def get_model_info( + self, + model_id: str, + category: Optional[ModelCategory] = None, + ) -> Optional[ModelInfo]: + return self.storage.get_model_info(model_id, category) - def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - return self.model_storage.show_models(req) + def get_model_infos( + self, + category: Optional[ModelCategory] = None, + model_type: Optional[str] = None, + ) -> List[ModelInfo]: + return self.storage.get_model_infos(category, model_type) - def register_built_in_model(self, model_info: ModelInfo): - self.model_storage.register_built_in_model(model_info) + def refresh(self): + """Refresh the model list (re-scan the file system)""" + self._models = self.storage.discover_all() - def get_model_info(self, model_id: str) -> ModelInfo: - return self.model_storage.get_model_info(model_id) + def get_registered_models(self) -> List[str]: + return self.storage.get_registered_models() - def update_model_state(self, model_id: str, state: ModelStates): - self.model_storage.update_model_state(model_id, state) + def is_model_registered(self, model_id: str) -> bool: + return self.storage.is_model_registered(model_id) - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: - """ - Get the type of the model with the given model_id. - """ - return self.model_storage.get_built_in_model_type(model_id) - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. +# Create a global model manager instance +_default_manager: Optional[ModelManager] = None - Args: - model_id (str): The ID of the model. - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. 
- """ - return self.model_storage.is_built_in_or_fine_tuned(model_id) +def get_model_manager() -> ModelManager: + global _default_manager + if _default_manager is None: + _default_manager = ModelManager() + return _default_manager diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 0264e27331a86..297a2b832d7a8 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -24,8 +24,7 @@ from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP +from iotdb.ainode.core.manager.model_manager import get_model_manager logger = Logger() @@ -47,7 +46,7 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = ModelManager().load_model(model_id, {}).to(device) + model = get_model_manager().load_model(model_id, {}).to(device) torch.cuda.synchronize(device) end = torch.cuda.memory_reserved(device) usage = end - start @@ -80,8 +79,8 @@ def evaluate_system_resources(device: torch.device) -> dict: def estimate_pool_size(device: torch.device, model_id: str) -> int: - model_info = BUILT_IN_LTSM_MAP.get(model_id, None) - if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP: + model_info = get_model_manager.get_model_info(model_id) + if model_info is None or model_info.model_id not in MODEL_MEM_USAGE_MAP: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." ) @@ -90,9 +89,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int: system_res = evaluate_system_resources(device) free_mem = system_res["free_mem"] - mem_usage = ( - MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO - ) + mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage) if size <= 0: logger.error( diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py deleted file mode 100644 index 3b55142350bad..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py +++ /dev/null @@ -1,1238 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -import os -from abc import abstractmethod -from typing import Callable, Dict, List - -import numpy as np -from huggingface_hub import hf_hub_download -from sklearn.preprocessing import MinMaxScaler -from sktime.detection.hmm_learn import GMMHMM, GaussianHMM -from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA -from sktime.forecasting.exp_smoothing import ExponentialSmoothing -from sktime.forecasting.naive import NaiveForecaster -from sktime.forecasting.trend import STLForecaster - -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - AttributeName, -) -from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - InferenceModelInternalError, - ListRangeException, - NumericalRangeException, - StringRangeException, - WrongAttributeTypeError, -) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import TIMER_REPO_ID -from iotdb.ainode.core.model.sundial import modeling_sundial -from iotdb.ainode.core.model.timerxl import modeling_timer - -logger = Logger() - - -def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: - weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) - if not os.path.exists(weights_path): - logger.info( - f"Model weights file not found at {weights_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - local_dir=local_dir, - ) - logger.info(f"Got file to {weights_path}") - except Exception as e: - logger.error( - f"Failed to download model weights file to {local_dir} due to {e}" - ) - return False - if not os.path.exists(config_path): - logger.info( - f"Model config file not found at {config_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, - local_dir=local_dir, - ) - logger.info(f"Got file to {config_path}") - except Exception as e: - logger.error( - f"Failed to download model config file to {local_dir} due to {e}" - ) - return False - return True - - -def download_built_in_ltsm_from_hf_if_necessary( - model_type: BuiltInModelType, local_dir: str -) -> bool: - """ - Download the built-in ltsm from HuggingFace repository when necessary. - - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - repo_id = TIMER_REPO_ID[model_type] - if not _download_file_from_hf_if_necessary(local_dir, repo_id): - return False - return True - - -def get_model_attributes(model_type: BuiltInModelType): - if model_type == BuiltInModelType.ARIMA: - attribute_map = arima_attribute_map - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - attribute_map = naive_forecaster_attribute_map - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - attribute_map = exponential_smoothing_attribute_map - elif model_type == BuiltInModelType.STL_FORECASTER: - attribute_map = stl_forecaster_attribute_map - elif model_type == BuiltInModelType.GMM_HMM: - attribute_map = gmmhmm_attribute_map - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - attribute_map = gaussian_hmm_attribute_map - elif model_type == BuiltInModelType.STRAY: - attribute_map = stray_attribute_map - elif model_type == BuiltInModelType.TIMER_XL: - attribute_map = timerxl_attribute_map - elif model_type == BuiltInModelType.SUNDIAL: - attribute_map = sundial_attribute_map - else: - raise BuiltInModelNotSupportError(model_type.value) - return attribute_map - - -def fetch_built_in_model( - model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str] -) -> Callable: - """ - Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config. - Args: - model_type: the type of the built-in model - model_dir: for huggingface models only, the directory where the model is stored - Returns: - model: the built-in model - """ - default_attributes = get_model_attributes(model_type) - # parse the attributes from inference_attrs - attributes = parse_attribute(inference_attrs, default_attributes) - - # build the built-in model - if model_type == BuiltInModelType.ARIMA: - model = ArimaModel(attributes) - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - model = ExponentialSmoothingModel(attributes) - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - model = NaiveForecasterModel(attributes) - elif model_type == BuiltInModelType.STL_FORECASTER: - model = STLForecasterModel(attributes) - elif model_type == BuiltInModelType.GMM_HMM: - model = GMMHMMModel(attributes) - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - model = GaussianHmmModel(attributes) - elif model_type == BuiltInModelType.STRAY: - model = STRAYModel(attributes) - elif model_type == BuiltInModelType.TIMER_XL: - model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) - elif model_type == BuiltInModelType.SUNDIAL: - model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) - else: - raise BuiltInModelNotSupportError(model_type.value) - - return model - - -class Attribute(object): - def __init__(self, name: str): - """ - Args: - name: the name of the attribute - """ - self._name = name - - @abstractmethod - def get_default_value(self): - raise NotImplementedError - - @abstractmethod - def validate_value(self, value): - raise NotImplementedError - - @abstractmethod - def parse(self, string_value: str): - raise NotImplementedError - - -class IntAttribute(Attribute): - def __init__( - self, - name: str, - default_value: int, - default_low: int, - default_high: int, - ): - super(IntAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return 
self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - int_value = int(string_value) - except: - raise WrongAttributeTypeError(self._name, "int") - return int_value - - -class FloatAttribute(Attribute): - def __init__( - self, - name: str, - default_value: float, - default_low: float, - default_high: float, - ): - super(FloatAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - float_value = float(string_value) - except: - raise WrongAttributeTypeError(self._name, "float") - return float_value - - -class StringAttribute(Attribute): - def __init__(self, name: str, default_value: str, value_choices: List[str]): - super(StringAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_choices = value_choices - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if value in self.__value_choices: - return True - raise StringRangeException(self._name, value, self.__value_choices) - - def parse(self, string_value: str): - return string_value - - -class BooleanAttribute(Attribute): - def __init__(self, name: str, default_value: bool): - super(BooleanAttribute, self).__init__(name) - self.__default_value = default_value - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if isinstance(value, bool): - return True - raise WrongAttributeTypeError(self._name, "bool") - - def parse(self, string_value: str): - if string_value.lower() == "true": - return True - elif string_value.lower() == "false": - return False - else: - raise WrongAttributeTypeError(self._name, "bool") - - -class ListAttribute(Attribute): - def __init__(self, name: str, default_value: List, value_type): - """ - value_type is the type of the elements in the list, e.g. 
int, float, str - """ - super(ListAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, list): - raise WrongAttributeTypeError(self._name, "list") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - list_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "list") - if not isinstance(list_value, list): - raise WrongAttributeTypeError(self._name, "list") - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - return list_value - - -class TupleAttribute(Attribute): - def __init__(self, name: str, default_value: tuple, value_type): - """ - value_type is the type of the elements in the list, e.g. int, float, str - """ - super(TupleAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - tuple_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "tuple") - if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - list_value = list(tuple_value) - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - tuple_value = tuple(list_value) - return tuple_value - - -def parse_attribute( - input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] -): - """ - Args: - input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of - the attribute - attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute - object - Returns: - a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute - """ - attributes = {} - for attribute_name in attribute_map: - # user specified the attribute - if attribute_name in input_attributes: - attribute = attribute_map[attribute_name] - value = attribute.parse(input_attributes[attribute_name]) - attribute.validate_value(value) - attributes[attribute_name] = value - # user did not specify the attribute, use the default value - else: - try: - attributes[attribute_name] = attribute_map[ - attribute_name - ].get_default_value() - except NotImplementedError as e: - logger.error(f"attribute {attribute_name} is not implemented.") - raise e - return attributes - - -sundial_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=16, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: 
IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=768, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=3072, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=12, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=12, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.DROPOUT_RATE.value: FloatAttribute( - name=AttributeName.DROPOUT_RATE.value, - default_value=0.1, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute( - name=AttributeName.FLOW_LOSS_DEPTH.value, - default_value=3, - default_low=1, - default_high=50, - ), - AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute( - name=AttributeName.NUM_SAMPLING_STEPS.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute( - name=AttributeName.DIFFUSION_BATCH_MUL.value, - default_value=4, - default_low=1, - default_high=5000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "sundial", - ), - value_choices=[""], - ), -} - -timerxl_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=96, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=1024, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=2048, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=8, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=8, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - 
value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.ATTENTION_DROPOUT.value: FloatAttribute( - name=AttributeName.ATTENTION_DROPOUT.value, - default_value=0.0, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "timerxl", - "model.safetensors", - ), - value_choices=[""], - ), -} - -# built-in sktime model attributes -# NaiveForecaster -naive_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.STRATEGY.value: StringAttribute( - name=AttributeName.STRATEGY.value, - default_value="last", - value_choices=["last", "mean"], - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 - ), -} -# ExponentialSmoothing -exponential_smoothing_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.DAMPED_TREND.value: BooleanAttribute( - name=AttributeName.DAMPED_TREND.value, - default_value=False, - ), - AttributeName.INITIALIZATION_METHOD.value: StringAttribute( - name=AttributeName.INITIALIZATION_METHOD.value, - default_value="estimated", - value_choices=["estimated", "heuristic", "legacy-heuristic", "known"], - ), - AttributeName.OPTIMIZED.value: BooleanAttribute( - name=AttributeName.OPTIMIZED.value, - default_value=True, - ), - AttributeName.REMOVE_BIAS.value: BooleanAttribute( - name=AttributeName.REMOVE_BIAS.value, - default_value=False, - ), - AttributeName.USE_BRUTE.value: BooleanAttribute( - name=AttributeName.USE_BRUTE.value, - default_value=False, - ), -} -# Arima -arima_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.ORDER.value: TupleAttribute( - name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int - ), - AttributeName.SEASONAL_ORDER.value: TupleAttribute( - name=AttributeName.SEASONAL_ORDER.value, - default_value=(0, 0, 0, 0), - value_type=int, - ), - AttributeName.METHOD.value: StringAttribute( - name=AttributeName.METHOD.value, - default_value="lbfgs", - value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], - ), - AttributeName.MAXITER.value: IntAttribute( - name=AttributeName.MAXITER.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( - name=AttributeName.SUPPRESS_WARNINGS.value, - default_value=True, - ), - AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute( - 
name=AttributeName.OUT_OF_SAMPLE_SIZE.value, - default_value=0, - default_low=0, - default_high=5000, - ), - AttributeName.SCORING.value: StringAttribute( - name=AttributeName.SCORING.value, - default_value="mse", - value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], - ), - AttributeName.WITH_INTERCEPT.value: BooleanAttribute( - name=AttributeName.WITH_INTERCEPT.value, - default_value=True, - ), - AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute( - name=AttributeName.TIME_VARYING_REGRESSION.value, - default_value=False, - ), - AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_STATIONARITY.value, - default_value=True, - ), - AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_INVERTIBILITY.value, - default_value=True, - ), - AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute( - name=AttributeName.SIMPLE_DIFFERENCING.value, - default_value=False, - ), - AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute( - name=AttributeName.MEASUREMENT_ERROR.value, - default_value=False, - ), - AttributeName.MLE_REGRESSION.value: BooleanAttribute( - name=AttributeName.MLE_REGRESSION.value, - default_value=True, - ), - AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute( - name=AttributeName.HAMILTON_REPRESENTATION.value, - default_value=False, - ), - AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( - name=AttributeName.CONCENTRATE_SCALE.value, - default_value=False, - ), -} -# STLForecaster -stl_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000 - ), - AttributeName.SEASONAL.value: IntAttribute( - name=AttributeName.SEASONAL.value, - default_value=7, - default_low=1, - default_high=5000, - ), - AttributeName.SEASONAL_DEG.value: IntAttribute( - name=AttributeName.SEASONAL_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_DEG.value: IntAttribute( - name=AttributeName.TREND_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOW_PASS_DEG.value: IntAttribute( - name=AttributeName.LOW_PASS_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.SEASONAL_JUMP.value: IntAttribute( - name=AttributeName.SEASONAL_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_JUMP.value: IntAttribute( - name=AttributeName.TREND_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOSS_PASS_JUMP.value: IntAttribute( - name=AttributeName.LOSS_PASS_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), -} - -# GAUSSIAN_HMM -gaussian_hmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["spherical", "diag", "full", "tied"], - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, 
- default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_PRIOR.value: FloatAttribute( - name=AttributeName.COVARS_PRIOR.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_WEIGHT.value: FloatAttribute( - name=AttributeName.COVARS_WEIGHT.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# GMMHMM -gmmhmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.N_MIX.value: IntAttribute( - name=AttributeName.N_MIX.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.WEIGHTS_PRIOR.value: FloatAttribute( - name=AttributeName.WEIGHTS_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["sperical", "diag", "full", "tied"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - 
default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# STRAY -stray_attribute_map = { - AttributeName.ALPHA.value: FloatAttribute( - name=AttributeName.ALPHA.value, - default_value=0.01, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.K.value: IntAttribute( - name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000 - ), - AttributeName.KNN_ALGORITHM.value: StringAttribute( - name=AttributeName.KNN_ALGORITHM.value, - default_value="brute", - value_choices=["brute", "kd_tree", "ball_tree", "auto"], - ), - AttributeName.P.value: FloatAttribute( - name=AttributeName.P.value, - default_value=0.5, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.SIZE_THRESHOLD.value: IntAttribute( - name=AttributeName.SIZE_THRESHOLD.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.OUTLIER_TAIL.value: StringAttribute( - name=AttributeName.OUTLIER_TAIL.value, - default_value="max", - value_choices=["min", "max"], - ), -} - - -class BuiltInModel(object): - def __init__(self, attributes): - self._attributes = attributes - self._model = None - - @abstractmethod - def inference(self, data): - raise NotImplementedError - - -class ArimaModel(BuiltInModel): - def __init__(self, attributes): - super(ArimaModel, self).__init__(attributes) - self._model = ARIMA( - order=attributes["order"], - seasonal_order=attributes["seasonal_order"], - method=attributes["method"], - suppress_warnings=attributes["suppress_warnings"], - maxiter=attributes["maxiter"], - out_of_sample_size=attributes["out_of_sample_size"], - scoring=attributes["scoring"], - with_intercept=attributes["with_intercept"], - time_varying_regression=attributes["time_varying_regression"], - enforce_stationarity=attributes["enforce_stationarity"], - enforce_invertibility=attributes["enforce_invertibility"], - simple_differencing=attributes["simple_differencing"], - measurement_error=attributes["measurement_error"], - mle_regression=attributes["mle_regression"], - hamilton_representation=attributes["hamilton_representation"], - concentrate_scale=attributes["concentrate_scale"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class 
ExponentialSmoothingModel(BuiltInModel): - def __init__(self, attributes): - super(ExponentialSmoothingModel, self).__init__(attributes) - self._model = ExponentialSmoothing( - damped_trend=attributes["damped_trend"], - initialization_method=attributes["initialization_method"], - optimized=attributes["optimized"], - remove_bias=attributes["remove_bias"], - use_brute=attributes["use_brute"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class NaiveForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(NaiveForecasterModel, self).__init__(attributes) - self._model = NaiveForecaster( - strategy=attributes["strategy"], sp=attributes["sp"] - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STLForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(STLForecasterModel, self).__init__(attributes) - self._model = STLForecaster( - sp=attributes["sp"], - seasonal=attributes["seasonal"], - seasonal_deg=attributes["seasonal_deg"], - trend_deg=attributes["trend_deg"], - low_pass_deg=attributes["low_pass_deg"], - seasonal_jump=attributes["seasonal_jump"], - trend_jump=attributes["trend_jump"], - low_pass_jump=attributes["low_pass_jump"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GMMHMMModel(BuiltInModel): - def __init__(self, attributes): - super(GMMHMMModel, self).__init__(attributes) - self._model = GMMHMM( - n_components=attributes["n_components"], - n_mix=attributes["n_mix"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - weights_prior=attributes["weights_prior"], - algorithm=attributes["algorithm"], - covariance_type=attributes["covariance_type"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GaussianHmmModel(BuiltInModel): - def __init__(self, attributes): - super(GaussianHmmModel, self).__init__(attributes) - self._model = GaussianHMM( - n_components=attributes["n_components"], - covariance_type=attributes["covariance_type"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - covars_prior=attributes["covars_prior"], - covars_weight=attributes["covars_weight"], - 
algorithm=attributes["algorithm"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STRAYModel(BuiltInModel): - def __init__(self, attributes): - super(STRAYModel, self).__init__(attributes) - self._model = STRAY( - alpha=attributes["alpha"], - k=attributes["k"], - knn_algorithm=attributes["knn_algorithm"], - p=attributes["p"], - size_threshold=attributes["size_threshold"], - outlier_tail=attributes["outlier_tail"], - ) - - def inference(self, data): - try: - data = MinMaxScaler().fit_transform(data) - output = self._model.fit_transform(data) - # change the output to int - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py index 348f9924316b6..a6a234a1ab8f5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py @@ -16,37 +16,23 @@ # under the License. # from enum import Enum -from typing import List -class BuiltInModelType(Enum): - # forecast models - ARIMA = "Arima" - HOLTWINTERS = "HoltWinters" - EXPONENTIAL_SMOOTHING = "ExponentialSmoothing" - NAIVE_FORECASTER = "NaiveForecaster" - STL_FORECASTER = "StlForecaster" - - # anomaly detection models - GAUSSIAN_HMM = "GaussianHmm" - GMM_HMM = "GmmHmm" - STRAY = "Stray" - - # large time series models (LTSM) - TIMER_XL = "Timer-XL" - # sundial - SUNDIAL = "Timer-Sundial" +class ModelCategory(Enum): + BUILTIN = "builtin" + USER_DEFINED = "user_defined" + FINETUNE = "finetune" - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - @staticmethod - def is_built_in_model(model_type: str) -> bool: - """ - Check if the given model type corresponds to a built-in model. 
- """ - return model_type in BuiltInModelType.values() +class ModelStates(Enum): + INACTIVE = "inactive" + ACTIVATING = "activating" + ACTIVE = "active" + LOADING = "loading" + LOADED = "loaded" + DROPPING = "dropping" + TRAINING = "training" + FAILED = "failed" class ModelFileType(Enum): @@ -55,16 +41,18 @@ class ModelFileType(Enum): UNKNOWN = "unknown" -class ModelCategory(Enum): - BUILT_IN = "BUILT-IN" - FINE_TUNED = "FINE-TUNED" - USER_DEFINED = "USER-DEFINED" +class UriType(Enum): + REPO = "repo" + FILE = "file" -class ModelStates(Enum): - ACTIVE = "ACTIVE" - INACTIVE = "INACTIVE" - LOADING = "LOADING" - DROPPING = "DROPPING" - TRAINING = "TRAINING" - FAILED = "FAILED" +# Map for inferring which HuggingFace repository to download from based on model ID +REPO_ID_MAP = { + "timerxl": "thuml/timer-base-84m", + "sundial": "thuml/sundial-base-128m", + # More mappings can be added as needed +} + +# Model file constants +MODEL_CONFIG_FILE = "config.json" +MODEL_WEIGHTS_FILE = "model.safetensors" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py deleted file mode 100644 index 26d863156f379..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py +++ /dev/null @@ -1,291 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import glob -import os -import shutil -from urllib.parse import urljoin - -import yaml - -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.uri_utils import ( - UriType, - download_file, - download_snapshot_from_hf, -) -from iotdb.ainode.core.util.serde import get_data_type_byte_from_str -from iotdb.thrift.ainode.ttypes import TConfigs - -logger = Logger() - - -def fetch_model_by_uri( - uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Fetch the model files from the specified URI. - - Args: - uri_type (UriType): type of the URI, either repo, file, http or https - uri (str): a network or a local path of the model to be registered - storage_path (str): path to save the whole model, including weights, config, codes, etc. 
- model_file_type (ModelFileType): The type of model file, either safetensors or pytorch - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if uri_type == UriType.REPO or uri_type in [UriType.HTTP, UriType.HTTPS]: - return _fetch_model_from_network(uri, storage_path, model_file_type) - elif uri_type == UriType.FILE: - return _fetch_model_from_local(uri, storage_path, model_file_type) - else: - raise InvalidUriError(f"Invalid URI type: {uri_type}") - - -def _fetch_model_from_network( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - download_snapshot_from_hf(uri, storage_path) - return _process_huggingface_files(storage_path) - - # TODO: The following codes might be refactored in future - # concat uri to get complete url - uri = uri if uri.endswith("/") else uri + "/" - target_model_path = urljoin(uri, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = urljoin(uri, MODEL_CONFIG_FILE_IN_YAML) - - # download config file - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - download_file(target_config_path, config_storage_path) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, download model file - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - download_file(target_model_path, model_storage_path) - return configs, attributes - - -def _fetch_model_from_local( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - # copy anything in the uri to local_dir - for file in os.listdir(uri): - shutil.copy(os.path.join(uri, file), storage_path) - return _process_huggingface_files(storage_path) - # concat uri to get complete path - target_model_path = os.path.join(uri, MODEL_WEIGHTS_FILE_IN_PT) - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = os.path.join(uri, MODEL_CONFIG_FILE_IN_YAML) - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - - # check if file exist - exist_model_file = os.path.exists(target_model_path) - exist_config_file = os.path.exists(target_config_path) - - configs = None - attributes = None - if exist_model_file and exist_config_file: - # copy config.yaml - shutil.copy(target_config_path, config_storage_path) - logger.info( - f"copy file from {target_config_path} to {config_storage_path} success" - ) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, copy model file - shutil.copy(target_model_path, model_storage_path) - logger.info( - f"copy file from {target_model_path} to {model_storage_path} success" - ) - - elif not exist_model_file or not exist_config_file: - raise InvalidUriError(uri) - - return configs, attributes - - -def _parse_inference_config(config_dict): - """ - Args: - config_dict: dict - - configs: dict - - input_shape (list): input shape of the model and needs to be two-dimensional array like [96, 2] - - 
output_shape (list): output shape of the model and needs to be two-dimensional array like [96, 2] - - input_type (list): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - output_type (list): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - attributes: dict - Returns: - configs: TConfigs - attributes: str - """ - configs = config_dict["configs"] - - # check if input_shape and output_shape are two-dimensional array - if not ( - isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2 - ): - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "input_shape should be a two-dimensional array.", - ) - if not ( - isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2 - ): - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "output_shape should be a two-dimensional array.", - ) - - # check if input_shape and output_shape are positive integer - input_shape_is_positive_number = ( - isinstance(configs["input_shape"][0], int) - and isinstance(configs["input_shape"][1], int) - and configs["input_shape"][0] > 0 - and configs["input_shape"][1] > 0 - ) - if not input_shape_is_positive_number: - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "element in input_shape should be positive integer.", - ) - - output_shape_is_positive_number = ( - isinstance(configs["output_shape"][0], int) - and isinstance(configs["output_shape"][1], int) - and configs["output_shape"][0] > 0 - and configs["output_shape"][1] > 0 - ) - if not output_shape_is_positive_number: - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "element in output_shape should be positive integer.", - ) - - # check if input_type and output_type are one-dimensional array with right length - if "input_type" in configs and not ( - isinstance(configs["input_type"], list) - and len(configs["input_type"]) == configs["input_shape"][1] - ): - raise BadConfigValueError( - "input_type", - configs["input_type"], - "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].", - ) - - if "output_type" in configs and not ( - isinstance(configs["output_type"], list) - and len(configs["output_type"]) == configs["output_shape"][1] - ): - raise BadConfigValueError( - "output_type", - configs["output_type"], - "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].", - ) - - # parse input_type and output_type to byte - if "input_type" in configs: - input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]] - else: - input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][ - 1 - ] - - if "output_type" in configs: - output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]] - else: - output_type = [get_data_type_byte_from_str("float32")] * configs[ - "output_shape" - ][1] - - # parse attributes - attributes = "" - if "attributes" in config_dict: - attributes = str(config_dict["attributes"]) - - return ( - TConfigs( - configs["input_shape"], configs["output_shape"], input_type, output_type - ), - attributes, - ) - - -def _process_huggingface_files(local_dir: str): - """ - TODO: Currently, we use this function to convert the model config from huggingface, we will refactor this in the future. 
- """ - config_file = None - for config_name in ["config.json", "model_config.json"]: - config_path = os.path.join(local_dir, config_name) - if os.path.exists(config_path): - config_file = config_path - break - - if not config_file: - raise InvalidUriError(f"No config.json found in {local_dir}") - - safetensors_files = glob.glob(os.path.join(local_dir, "*.safetensors")) - if not safetensors_files: - raise InvalidUriError(f"No .safetensors files found in {local_dir}") - - simple_config = { - "configs": { - "input_shape": [96, 1], - "output_shape": [96, 1], - "input_type": ["float32"], - "output_type": ["float32"], - }, - "attributes": { - "model_type": "huggingface_model", - "source_dir": local_dir, - "files": [os.path.basename(f) for f in safetensors_files], - }, - } - - configs, attributes = _parse_inference_config(simple_config) - return configs, attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 167bfd76640d1..f36ad582a837b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -15,53 +15,21 @@ # specific language governing permissions and limitations # under the License. # -import glob -import os -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) +from typing import Dict, List, Optional, Tuple +from iotdb.ainode.core.model.model_enums import ModelCategory, ModelStates -def get_model_file_type(model_path: str) -> ModelFileType: - """ - Determine the file type of the specified model directory. 
- """ - if _has_safetensors_format(model_path): - return ModelFileType.SAFETENSORS - elif _has_pytorch_format(model_path): - return ModelFileType.PYTORCH - else: - return ModelFileType.UNKNOWN - - -def _has_safetensors_format(path: str) -> bool: - """Check if directory contains safetensors files.""" - safetensors_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)) - json_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_JSON)) - return len(safetensors_files) > 0 and len(json_files) > 0 - - -def _has_pytorch_format(path: str) -> bool: - """Check if directory contains pytorch files.""" - pt_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_PT)) - yaml_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_YAML)) - return len(pt_files) > 0 and len(yaml_files) > 0 - +# Map for inferring which HuggingFace repository to download from based on model ID +REPO_ID_MAP = { + "timerxl": "thuml/timer-base-84m", + "sundial": "thuml/sundial-base-128m", + # More mappings can be added as needed +} -def get_built_in_model_type(model_type: str) -> BuiltInModelType: - if not BuiltInModelType.is_built_in_model(model_type): - raise ValueError(f"Invalid built-in model type: {model_type}") - return BuiltInModelType(model_type) +# Model file constants +MODEL_CONFIG_FILE = "config.json" +MODEL_WEIGHTS_FILE = "model.safetensors" class ModelInfo: @@ -71,84 +39,91 @@ def __init__( model_type: str, category: ModelCategory, state: ModelStates, + path: str = "", + auto_map: Optional[Dict] = None, + _transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type self.category = category self.state = state + self.path = path + self.auto_map = auto_map # If exists, indicates it's a Transformers model + self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + def __repr__(self): + return ( + f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " + f"category={self.category.value}, state={self.state.value}, " + f"path={self.path}, has_auto_map={self.auto_map is not None})" + ) -TIMER_REPO_ID = { - BuiltInModelType.TIMER_XL: "thuml/timer-base-84m", - BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m", -} -# Built-in machine learning models, they can be employed directly -BUILT_IN_MACHINE_LEARNING_MODEL_MAP = { +BUILTIN_SKTIME_MODEL_MAP = { # forecast models "arima": ModelInfo( model_id="arima", - model_type=BuiltInModelType.ARIMA.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), "holtwinters": ModelInfo( model_id="holtwinters", - model_type=BuiltInModelType.HOLTWINTERS.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), "exponential_smoothing": ModelInfo( model_id="exponential_smoothing", - model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), "naive_forecaster": ModelInfo( model_id="naive_forecaster", - model_type=BuiltInModelType.NAIVE_FORECASTER.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), "stl_forecaster": ModelInfo( model_id="stl_forecaster", - model_type=BuiltInModelType.STL_FORECASTER.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), # anomaly 
detection models "gaussian_hmm": ModelInfo( model_id="gaussian_hmm", - model_type=BuiltInModelType.GAUSSIAN_HMM.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), "gmm_hmm": ModelInfo( model_id="gmm_hmm", - model_type=BuiltInModelType.GMM_HMM.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), "stray": ModelInfo( model_id="stray", - model_type=BuiltInModelType.STRAY.value, - category=ModelCategory.BUILT_IN, + model_type="sktime", + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, ), } -# Built-in large time series models (LTSM), their weights are not included in AINode by default -BUILT_IN_LTSM_MAP = { - "timer_xl": ModelInfo( - model_id="timer_xl", - model_type=BuiltInModelType.TIMER_XL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, +# Built-in huggingface transformers models, their weights are not included in AINode by default +BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { + "timerxl": ModelInfo( + model_id="timerxl", + model_type="timer", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, ), "sundial": ModelInfo( model_id="sundial", - model_type=BuiltInModelType.SUNDIAL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + model_type="sundial", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py new file mode 100644 index 0000000000000..9b70b5f80401e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import os +from typing import Any + +import torch +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForNextSentencePrediction, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTimeSeriesPrediction, + AutoModelForTokenClassification, +) + +from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_enums import ModelCategory +from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelStorage +from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model + +logger = Logger() + + +class ModelLoader: + """Model loader - unified interface for loading different types of models""" + + def __init__(self, storage: ModelStorage): + self.storage = storage + + def load_model(self, model_id: str, **kwargs) -> Any: + # Lazy registration: if it's a Transformers model and not registered, register it first + model_info = self.storage.ensure_transformers_registered(model_id) + if not model_info: + logger.error( + f"Model {model_id} failed to register to Transformers, cannot load." + ) + return None + + if model_info.auto_map is not None: + model = self.load_model_from_transformers(model_info, **kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_id) + else: + model = self.load_model_from_pt(model_info, **kwargs) + + logger.info(f"Model {model_id} loaded to device {model.device} successfully.") + return model + + def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): + model_config, load_class = None, None + device_map = kwargs.get("device_map", "cpu") + trust_remote_code = kwargs.get("trust_remote_code", True) + train_from_scratch = kwargs.get("train_from_scratch", False) + + if model_info.category == ModelCategory.BUILTIN: + if model_info.model_id == "timerxl": + from iotdb.ainode.core.model.timerxl.configuration_timer import ( + TimerConfig, + ) + + model_config = TimerConfig() + from iotdb.ainode.core.model.timerxl.modeling_timer import ( + TimerForPrediction, + ) + + load_class = TimerForPrediction + elif model_info.model_id == "sundial": + from iotdb.ainode.core.model.sundial.configuration_sundial import ( + SundialConfig, + ) + + model_config = SundialConfig() + from iotdb.ainode.core.model.sundial.modeling_sundial import ( + SundialForPrediction, + ) + + load_class = SundialForPrediction + else: + logger.error( + f"Unsupported built-in Transformers model {model_info.model_id}." 
+ ) + else: + model_config = AutoConfig.from_pretrained(model_info.path) + if ( + type(model_config) + in AutoModelForTimeSeriesPrediction._model_mapping.keys() + ): + load_class = AutoModelForTimeSeriesPrediction + elif ( + type(model_config) + in AutoModelForNextSentencePrediction._model_mapping.keys() + ): + load_class = AutoModelForNextSentencePrediction + elif type(model_config) in AutoModelForSeq2SeqLM._model_mapping.keys(): + load_class = AutoModelForSeq2SeqLM + elif ( + type(model_config) + in AutoModelForSequenceClassification._model_mapping.keys() + ): + load_class = AutoModelForSequenceClassification + elif ( + type(model_config) + in AutoModelForTokenClassification._model_mapping.keys() + ): + load_class = AutoModelForTokenClassification + else: + load_class = AutoModelForCausalLM + + if train_from_scratch: + model = load_class.from_config( + model_config, trust_remote_code=trust_remote_code, device_map=device_map + ) + else: + model = load_class.from_pretrained( + model_info.path, + trust_remote_code=trust_remote_code, + device_map=device_map, + ) + + return model + + def load_model_from_pt(self, model_info: ModelInfo, **kwargs): + device_map = kwargs.get("device_map", "cpu") + acceleration = kwargs.get("acceleration", False) + model_path = os.path.join(model_info.path, "model.pt") + if not os.path.exists(model_path): + logger.error(f"Model file not found at {model_path}.") + raise ModelNotExistError(model_path) + model = torch.jit.load(model_path) + if ( + isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + or not acceleration + ): + return model + try: + model = torch.compile(model) + except Exception as e: + logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") + return model.to(device_map) + + def load_model_for_efficient_inference(self): + # TODO: An efficient model loading method for inference based on model_arguments + pass + + def load_model_for_powerful_finetune(self): + # TODO: An powerful model loading method for finetune based on model_arguments + pass + + def unload_model(self): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index e346f569102e3..5ba67e158a5dc 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -17,46 +17,23 @@ # import concurrent.futures -import json import os import shutil -from collections.abc import Callable -from typing import Dict +from typing import List, Optional -import torch -from torch import nn +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri +from iotdb.ainode.core.model.model_enums import 
REPO_ID_MAP, ModelCategory, ModelStates from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from iotdb.ainode.core.model.utils import * from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus @@ -64,393 +41,536 @@ logger = Logger() -class ModelStorage(object): - def __init__(self): - self._model_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() - ) - if not os.path.exists(self._model_dir): - try: - os.makedirs(self._model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._builtin_model_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir() - ) - if not os.path.exists(self._builtin_model_dir): - try: - os.makedirs(self._builtin_model_dir) - except PermissionError as e: - logger.error(e) - raise e +class ModelStorage: + """Model storage class - unified management of model discovery and registration""" + + def __init__(self, models_dir: str): + self.models_dir = Path(models_dir) + # Unified storage: category -> {model_id -> ModelInfo} + self._models: Dict[str, Dict[str, ModelInfo]] = { + ModelCategory.BUILTIN.value: {}, + ModelCategory.USER_DEFINED.value: {}, + ModelCategory.FINETUNE.value: {}, + } + # Async download executor (using single-threaded executor because hf download interface is unstable with concurrent downloads) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Thread lock pool for protecting concurrent access to model information self._lock_pool = ModelLockPool() - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1 - ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading. 
- self._model_info_map: Dict[str, ModelInfo] = {} - self._init_model_info_map() + self._initialize_directories() + + def _initialize_directories(self): + """Initialize directory structure and ensure __init__.py files exist""" + self.models_dir.mkdir(parents=True, exist_ok=True) + ensure_init_file(self.models_dir) + + for category in ModelCategory: + category_path = self.models_dir / category.value + category_path.mkdir(parents=True, exist_ok=True) + ensure_init_file(category_path) + + # ==================== Discovery Methods ==================== + + def discover_all(self) -> Dict[str, Dict[str, ModelInfo]]: + """Scan file system to discover all models""" + self._discover_category(ModelCategory.BUILTIN) + self._discover_category(ModelCategory.USER_DEFINED) + self._discover_category(ModelCategory.FINETUNE) + return self._models + + def _discover_category(self, category: ModelCategory): + """Discover all models in a category directory""" + category_path = self.models_dir / category.value + if not category_path.exists(): + return + + if category == ModelCategory.BUILTIN: + self._discover_builtin_models(category_path) + else: + # For finetune and user_defined, scan directories + for item in category_path.iterdir(): + if item.is_dir() and not item.name.startswith("__"): + relative_path = item.relative_to(category_path) + model_id = str(relative_path).replace("/", "_").replace("\\", "_") + self._process_model_directory(item, model_id, category) + + def _discover_builtin_models(self, category_path: Path): + # Register SKTIME models directly from map + for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_SKTIME_MODEL_MAP[model_id] + ) + + # Process HuggingFace Transformers models + for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys(): + model_dir = category_path / model_id + model_dir.mkdir(parents=True, exist_ok=True) + self._process_model_directory(model_dir, model_id, ModelCategory.BUILTIN) + + def _process_model_directory( + self, model_dir: Path, model_id: str, category: ModelCategory + ): + """Handling the discovery logic for a single model directory.""" + ensure_init_file(model_dir) + + config_path = model_dir / MODEL_CONFIG_FILE + weights_path = model_dir / MODEL_WEIGHTS_FILE + needs_download = not config_path.exists() or not weights_path.exists() + + if needs_download: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( + model_id=model_id, + model_type="", # Read from config.json after download + category=category, + state=ModelStates.ACTIVATING, + path=str(model_dir), + auto_map=None, + _transformers_registered=False, + ) + self._models[category.value][model_id] = model_info - def _init_model_info_map(self): - """ - Initialize the model info map. - """ - # 1. initialize built-in and ready-to-use models - for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP: - self._model_info_map[model_id] = BUILT_IN_MACHINE_LEARNING_MODEL_MAP[ - model_id - ] - # 2. retrieve fine-tuned models from the built-in model directory - fine_tuned_models = self._retrieve_fine_tuned_models() - for model_id in fine_tuned_models: - self._model_info_map[model_id] = fine_tuned_models[model_id] - # 3. 
automatically downloading the weights of built-in LSTM models when necessary - for model_id in BUILT_IN_LTSM_MAP: - if model_id not in self._model_info_map: - self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id] future = self._executor.submit( - self._download_built_in_model_if_necessary, model_id + self._download_model_if_necessary, str(model_dir), model_id ) future.add_done_callback( - lambda f, mid=model_id: self._callback_model_download_result(f, mid) - ) - # 4. retrieve user-defined models from the model directory - user_defined_models = self._retrieve_user_defined_models() - for model_id in user_defined_models: - self._model_info_map[model_id] = user_defined_models[model_id] - - def _retrieve_fine_tuned_models(self): - """ - Retrieve fine-tuned models from the built-in model directory. - - Returns: - {"model_id": ModelInfo} - """ - result = {} - build_in_dirs = [ - d - for d in os.listdir(self._builtin_model_dir) - if os.path.isdir(os.path.join(self._builtin_model_dir, d)) - ] - for model_id in build_in_dirs: - config_file_path = os.path.join( - self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON + lambda f, mid=model_id, cat=category: self._callback_model_download_result( + f, mid, cat + ) ) - if os.path.isfile(config_file_path): - with open(config_file_path, "r") as f: - model_config = json.load(f) - if "model_type" in model_config: - model_type = model_config["model_type"] - model_info = ModelInfo( - model_id=model_id, - model_type=model_type, - category=ModelCategory.FINE_TUNED, - state=ModelStates.ACTIVE, - ) - # Refactor the built-in model category - if "timer_xl" == model_id: - model_info.category = ModelCategory.BUILT_IN - if "sundial" == model_id: - model_info.category = ModelCategory.BUILT_IN - # Compatible patch with the codes in HuggingFace - if "timer" == model_type: - model_info.model_type = BuiltInModelType.TIMER_XL.value - if "sundial" == model_type: - model_info.model_type = BuiltInModelType.SUNDIAL.value - result[model_id] = model_info - return result - - def _download_built_in_model_if_necessary(self, model_id: str) -> bool: - """ - Download the built-in model if it is not already downloaded. + else: + config = load_model_config(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( + model_id=model_id, + model_type=model_type, + category=category, + state=ModelStates.ACTIVE, + path=str(model_dir), + auto_map=auto_map, + _transformers_registered=False, # Lazy registration + ) + self._models[category.value][model_id] = model_info - Args: - model_id (str): The ID of the model to download. + def _download_model_if_necessary(self, model_dir: str, model_id: str) -> bool: + """Returns: True if the model is existed or downloaded successfully, False otherwise.""" + if model_id in REPO_ID_MAP: + repo_id = REPO_ID_MAP[model_id] + else: + logger.error(f"Model {model_id} not found in REPO_ID_MAP") + return False - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - with self._lock_pool.get_lock(model_id).write_lock(): - local_dir = os.path.join(self._builtin_model_dir, model_id) - return download_built_in_ltsm_from_hf_if_necessary( - get_built_in_model_type(self._model_info_map[model_id].model_type), - local_dir, - ) + weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE) + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE) - def _callback_model_download_result(self, future, model_id: str): - with self._lock_pool.get_lock(model_id).write_lock(): - if future.result(): - self._model_info_map[model_id].state = ModelStates.ACTIVE - logger.info( - f"The built-in model: {model_id} is active and ready to use." + if not os.path.exists(weights_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_WEIGHTS_FILE, + local_dir=model_dir, ) - else: - self._model_info_map[model_id].state = ModelStates.INACTIVE + except Exception as e: + logger.error(f"Failed to download model weights from HuggingFace: {e}") + return False - def _retrieve_user_defined_models(self): - """ - Retrieve user_defined models from the model directory. + if not os.path.exists(config_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_CONFIG_FILE, + local_dir=model_dir, + ) + except Exception as e: + logger.error(f"Failed to download model config from HuggingFace: {e}") + return False - Returns: - {"model_id": ModelInfo} - """ - result = {} - user_dirs = [ - d - for d in os.listdir(self._model_dir) - if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights" - ] - for model_id in user_dirs: - result[model_id] = ModelInfo( - model_id=model_id, - model_type="", - category=ModelCategory.USER_DEFINED, - state=ModelStates.ACTIVE, - ) - return result + return True - def register_model(self, model_id: str, uri: str): - """ - Args: - model_id: id of model to register - uri: network or local dir path of the model to register - Returns: - configs: TConfigs - attributes: str - """ + def _callback_model_download_result( + self, future, model_id: str, category: ModelCategory + ): + """Callback function for handling model download results""" with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - # create storage dir if not exist - if not os.path.exists(storage_path): - os.makedirs(storage_path) - uri_type, parsed_uri, model_file_type = get_model_register_strategy(uri) - self._model_info_map[model_id] = ModelInfo( - model_id=model_id, - model_type="", - category=ModelCategory.USER_DEFINED, - state=ModelStates.LOADING, - ) try: - # TODO: The uri should be fetched asynchronously - configs, attributes = fetch_model_by_uri( - uri_type, parsed_uri, storage_path, model_file_type - ) - self._model_info_map[model_id].state = ModelStates.ACTIVE - return configs, attributes + if future.result(): + if model_id in self._models[category.value]: + model_info = self._models[category.value][model_id] + model_info.state = ModelStates.ACTIVE + config_path = os.path.join(model_info.path, MODEL_CONFIG_FILE) + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + model_info.model_type = config.get("model_type", "") + model_info.auto_map = config.get("auto_map") + logger.info( + f"Model {model_id} downloaded successfully and is ready to use." 
+ ) + else: + if model_id in self._models[category.value]: + self._models[category.value][ + model_id + ].state = ModelStates.INACTIVE + logger.warning(f"Failed to download model {model_id}.") except Exception as e: - logger.error(f"Failed to register model {model_id}: {e}") - self._model_info_map[model_id].state = ModelStates.INACTIVE - raise e + logger.error(f"Error in download callback for model {model_id}: {e}") + if model_id in self._models[category.value]: + self._models[category.value][model_id].state = ModelStates.INACTIVE - def delete_model(self, model_id: str) -> None: + # ==================== Registration Methods ==================== + + def register_model(self, model_id: str, uri: str) -> bool: """ - Args: - model_id: id of model to delete - Returns: - None + Supported URI formats: + - repo:// + - file:// """ - # check if the model is built-in - with self._lock_pool.get_lock(model_id).read_lock(): - if self._is_built_in(model_id): - raise BuiltInModelDeletionError(model_id) + uri_type = parse_uri_type(uri) + parsed_uri = get_parsed_uri(uri) - # delete the user-defined or fine-tuned model - with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - storage_path = os.path.join(self._builtin_model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - if model_id in self._model_info_map: - del self._model_info_map[model_id] - logger.info(f"Model {model_id} deleted successfully.") - - def _is_built_in(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in model. + model_dir = Path(self.models_dir) / "user_defined" / model_id + model_dir.mkdir(parents=True, exist_ok=True) + ensure_init_file(model_dir) - Args: - model_id (str): The ID of the model. + if uri_type == UriType.REPO: + self._fetch_model_from_hf_repo(parsed_uri, str(model_dir)) + else: + self._fetch_model_from_local(os.path.expanduser(parsed_uri), str(model_dir)) - Returns: - bool: True if the model is built-in, False otherwise. 
- """ - return ( - model_id in self._model_info_map - and self._model_info_map[model_id].category == ModelCategory.BUILT_IN + config_path, _ = validate_model_files(model_dir) + config = load_model_config(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( + model_id=model_id, + model_type=model_type, + category=ModelCategory.USER_DEFINED, + state=ModelStates.ACTIVE, + path=str(model_dir), + auto_map=auto_map, + _transformers_registered=False, # Register later + ) + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + if auto_map: + # Transformers model: immediately register to Transformers auto-loading mechanism + success = self._register_transformers_model(model_info) + if success: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info._transformers_registered = True + else: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info.state = ModelStates.INACTIVE + logger.error(f"Failed to register Transformers model {model_id}") + return False + else: + # Other type models: only log + self._register_other_model(model_info) + + logger.info(f"Successfully registered model {model_id} from URI: {uri}") + return True + + def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): + logger.info( + f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" ) - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise - Args: - model_id (str): The ID of the model. + def _fetch_model_from_local(self, source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. 
- """ - return model_id in self._model_info_map and ( - self._model_info_map[model_id].category == ModelCategory.BUILT_IN - or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED - ) + source_dir = Path(source_path) + if not source_dir.is_dir(): + raise ValueError( + f"Source path does not exist or is not a directory: {source_path}" + ) - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool - ) -> Callable: - """ - Load a model with automatic detection of .safetensors or .pt format + source_config = source_dir / MODEL_CONFIG_FILE + source_weights = source_dir / MODEL_WEIGHTS_FILE + if not source_config.exists(): + raise ValueError( + f"Config file missing in source directory: {source_config}" + ) + if not source_weights.exists(): + raise ValueError( + f"Weights file missing in source directory: {source_weights}" + ) - Returns: - model: The model instance corresponding to specific model_id + # Copy all files + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + + def _register_transformers_model(self, model_info: ModelInfo) -> bool: + """ + Register Transformers model to auto-loading mechanism (internal method) """ - with self._lock_pool.get_lock(model_id).read_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - return fetch_built_in_model( - get_built_in_model_type(self._model_info_map[model_id].model_type), - model_dir, - inference_attrs, + auto_map = model_info.auto_map + if not auto_map: + return False + + auto_config_path = auto_map.get("AutoConfig") + auto_model_path = auto_map.get("AutoModelForCausalLM") + + try: + module_parent = str(Path(model_info.path).parent.absolute()) + with temporary_sys_path(module_parent): + config_class = import_class_from_path( + model_info.model_id, auto_config_path + ) + AutoConfig.register(model_info.model_type, config_class) + logger.info( + f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}" ) - else: - # load the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - model_file_type = get_model_file_type(model_dir) - if model_file_type == ModelFileType.SAFETENSORS: - # TODO: Support this function - raise UnsupportedError("SAFETENSORS format") - else: - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - - if not os.path.exists(model_path): - raise ModelNotExistError(model_path) - model = torch.jit.load(model_path) - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - - try: - model = torch.compile(model) - except Exception as e: - logger.warning( - f"acceleration failed, fallback to normal mode: {str(e)}" - ) - return model - def save_model(self, model_id: str, model: nn.Module): - """ - Save the model using save_pretrained + model_class = import_class_from_path( + model_info.model_id, auto_model_path + ) + AutoModelForCausalLM.register(config_class, model_class) + logger.info( + f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" + ) + return True + except Exception as e: + logger.warning( + f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct." 
+ ) + return False + + def _register_other_model(self, model_info: ModelInfo): + """Register other type models (non-Transformers models)""" + logger.info( + f"Registered other type model: {model_info.model_id} ({model_info.model_type})" + ) + + def ensure_transformers_registered(self, model_id: str) -> "ModelInfo": + """ + Ensure Transformers model is registered (called for lazy registration) + This method uses locks to ensure thread safety. All check logic is within lock protection. Returns: - Whether saving succeeded + str: If None, registration failed, otherwise returns model path """ + # Use lock to protect entire check-execute process with self._lock_pool.get_lock(model_id).write_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - model.save_pretrained(model_dir) - else: - # save the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - os.makedirs(model_dir, exist_ok=True) - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - try: - scripted_model = ( - model - if isinstance(model, torch.jit.ScriptModule) - else torch.jit.script(model) + # Directly access _models dictionary (avoid calling get_model_info which may cause deadlock) + model_info = None + for category_dict in self._models.values(): + if model_id in category_dict: + model_info = category_dict[model_id] + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot register") + return None + + # If already registered, return directly + if model_info._transformers_registered: + return model_info + + # If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks) + if ( + not model_info.auto_map + or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys() + ): + model_info._transformers_registered = True + return model_info + + # Execute registration (under lock protection) + try: + success = self._register_transformers_model(model_info) + if success: + model_info._transformers_registered = True + logger.info( + f"Model {model_id} successfully registered to Transformers" ) - torch.jit.save(scripted_model, model_path) - except Exception as e: - logger.error(f"Failed to save scripted model: {e}") - - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. + return model_info + else: + model_info.state = ModelStates.INACTIVE + logger.error(f"Model {model_id} failed to register to Transformers") + return None - Args: - model_id (str): The ID of the model. + except Exception as e: + # Ensure state consistency in exception cases + model_info.state = ModelStates.INACTIVE + model_info._transformers_registered = False + logger.error( + f"Exception occurred while registering model {model_id} to Transformers: {e}" + ) + return None - Returns: - str: The path to the checkpoint file for the model. 
- """ - # Only support built-in models for now - return os.path.join(self._builtin_model_dir, f"{model_id}") + # ==================== Show and Delete Models ==================== def show_models(self, req: TShowModelsReq) -> TShowModelsResp: resp_status = TSStatus( code=TSStatusCode.SUCCESS_STATUS.value, message="Show models successfully", ) - if req.modelId: - if req.modelId in self._model_info_map: - model_info = self._model_info_map[req.modelId] - return TShowModelsResp( - status=resp_status, - modelIdList=[req.modelId], - modelTypeMap={req.modelId: model_info.model_type}, - categoryMap={req.modelId: model_info.category.value}, - stateMap={req.modelId: model_info.state.value}, - ) + + # Use global lock to protect entire dictionary structure + with self._lock_pool.get_lock("").read_lock(): + if req.modelId: + # Find specified model + model_info = None + for category_dict in self._models.values(): + if req.modelId in category_dict: + model_info = category_dict[req.modelId] + break + + if model_info: + return TShowModelsResp( + status=resp_status, + modelIdList=[req.modelId], + modelTypeMap={req.modelId: model_info.model_type}, + categoryMap={req.modelId: model_info.category.value}, + stateMap={req.modelId: model_info.state.value}, + ) + else: + return TShowModelsResp( + status=resp_status, + modelIdList=[], + modelTypeMap={}, + categoryMap={}, + stateMap={}, + ) else: + # Return all models + model_id_list = [] + model_type_map = {} + category_map = {} + state_map = {} + + for category_dict in self._models.values(): + for model_id, model_info in category_dict.items(): + model_id_list.append(model_id) + model_type_map[model_id] = model_info.model_type + category_map[model_id] = model_info.category.value + state_map[model_id] = model_info.state.value + return TShowModelsResp( status=resp_status, - modelIdList=[], - modelTypeMap={}, - categoryMap={}, - stateMap={}, + modelIdList=model_id_list, + modelTypeMap=model_type_map, + categoryMap=category_map, + stateMap=state_map, ) - return TShowModelsResp( - status=resp_status, - modelIdList=list(self._model_info_map.keys()), - modelTypeMap=dict( - (model_id, model_info.model_type) - for model_id, model_info in self._model_info_map.items() - ), - categoryMap=dict( - (model_id, model_info.category.value) - for model_id, model_info in self._model_info_map.items() - ), - stateMap=dict( - (model_id, model_info.state.value) - for model_id, model_info in self._model_info_map.items() - ), - ) - def register_built_in_model(self, model_info: ModelInfo): - with self._lock_pool.get_lock(model_info.model_id).write_lock(): - self._model_info_map[model_info.model_id] = model_info + def delete_model(self, model_id: str) -> None: + # Use write lock to protect entire deletion process + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = None + category_value = None + for cat_value, category_dict in self._models.items(): + if model_id in category_dict: + model_info = category_dict[model_id] + category_value = cat_value + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot delete") + return + + if model_info.category == ModelCategory.BUILTIN: + raise BuiltInModelDeletionError(model_id) - def get_model_info(self, model_id: str) -> ModelInfo: - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return self._model_info_map[model_id] - else: - raise ValueError(f"Model {model_id} does not exist.") + model_path = Path(model_info.path) + if model_path.exists(): + try: + 
shutil.rmtree(model_path) + logger.info(f"Deleted model directory: {model_path}") + except Exception as e: + logger.error(f"Failed to delete model directory {model_path}: {e}") + raise - def update_model_state(self, model_id: str, state: ModelStates): - with self._lock_pool.get_lock(model_id).write_lock(): - if model_id in self._model_info_map: - self._model_info_map[model_id].state = state - else: - raise ValueError(f"Model {model_id} does not exist.") + if category_value and model_id in self._models[category_value]: + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from storage") + + return - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: + # ==================== Query Methods ==================== + + def get_model_info( + self, model_id: str, category: Optional[ModelCategory] = None + ) -> Optional[ModelInfo]: """ - Get the type of the model with the given model_id. + Get single model information - Args: - model_id (str): The ID of the model. + If category is specified, use model_id's lock + If category is not specified, need to traverse all dictionaries, use global lock + """ + if category: + # Category specified, only need to access specific dictionary, use model_id's lock + with self._lock_pool.get_lock(model_id).read_lock(): + return self._models[category.value].get(model_id) + else: + # Category not specified, need to traverse all dictionaries, use global lock + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return category_dict[model_id] + return None + + def get_model_infos( + self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None + ) -> List[ModelInfo]: + """ + Get model information list - Returns: - str: The type of the model. 
+ Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure + For single model access, using model_id-based lock would be more efficient """ - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return get_built_in_model_type( - self._model_info_map[model_id].model_type - ) + matching_models = [] + + # For traversal operations, we need to protect the entire dictionary structure + # Use a special lock (using empty string as key) to protect the entire dictionary + with self._lock_pool.get_lock("").read_lock(): + if category and model_type: + for model_info in self._models[category.value].values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models + elif category: + return list(self._models[category.value].values()) + elif model_type: + for category_dict in self._models.values(): + for model_info in category_dict.values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models else: - raise ValueError(f"Model {model_id} does not exist.") + for category_dict in self._models.values(): + matching_models.extend(category_dict.values()) + return matching_models + + def is_model_registered(self, model_id: str) -> bool: + """Check if model is registered (search in _models)""" + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return True + return False + + def get_registered_models(self) -> List[str]: + """Get list of all registered model IDs""" + with self._lock_pool.get_lock("").read_lock(): + model_ids = [] + for category_dict in self._models.values(): + model_ids.extend(category_dict.keys()) + return model_ids diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json new file mode 100644 index 0000000000000..dcdc133529090 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "name": "ARIMA", + "predict_length": 1, + "order": [1, 0, 0], + "seasonal_order": [0, 0, 0, 0], + "method": "lbfgs", + "maxiter": 1, + "suppress_warnings": true, + "out_of_sample_size": 0, + "scoring": "mse", + "with_intercept": true, + "time_varying_regression": false, + "enforce_stationarity": true, + "enforce_invertibility": true, + "simple_differencing": false, + "measurement_error": false, + "mle_regression": true, + "hamilton_representation": false, + "concentrate_scale": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py new file mode 100644 index 0000000000000..bd780da3a73fb --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
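For reference, a minimal usage sketch of how the refactored ModelStorage and ModelLoader introduced above are meant to compose. This is not part of the patch: the models directory path and model id are illustrative, and the constructor signatures are assumed to be exactly the ones added in this diff.

    # Illustrative sketch only; assumes the APIs introduced above.
    from iotdb.ainode.core.model.model_loader import ModelLoader
    from iotdb.ainode.core.model.model_storage import ModelStorage

    storage = ModelStorage("data/ainode/models")  # hypothetical models_dir
    storage.discover_all()  # scans the builtin / finetune / user_defined subdirectories

    loader = ModelLoader(storage)
    # load_model() lazily registers Transformers models via
    # storage.ensure_transformers_registered() before loading them.
    model = loader.load_model("sundial", device_map="cpu")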
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +Sktime model configuration module - simplified version +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +@dataclass +class AttributeConfig: + """Base class for attribute configuration""" + + name: str + default: Any + type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple' + low: Union[int, float, None] = None + high: Union[int, float, None] = None + choices: List[str] = field(default_factory=list) + value_type: type = None # Element type for list and tuple + + def validate_value(self, value): + """Validate if the value meets the requirements""" + if self.type == "int": + if not isinstance(value, int): + raise WrongAttributeTypeError(self.name, "int") + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "float": + if not isinstance(value, (int, float)): + raise WrongAttributeTypeError(self.name, "float") + value = float(value) + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "str": + if not isinstance(value, str): + raise WrongAttributeTypeError(self.name, "str") + if self.choices and value not in self.choices: + raise StringRangeException(self.name, value, self.choices) + elif self.type == "bool": + if not isinstance(value, bool): + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + if not isinstance(value, list): + raise WrongAttributeTypeError(self.name, "list") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + elif self.type == "tuple": + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + return True + + def parse(self, string_value: str): + """Parse string value to corresponding type""" + if self.type == "int": + try: + return int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + elif self.type == "float": + try: + return float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + elif self.type == "str": + return string_value + elif self.type == "bool": + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + else: + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") 
+            if not isinstance(list_value, list):
+                raise WrongAttributeTypeError(self.name, "list")
+            for i in range(len(list_value)):
+                try:
+                    list_value[i] = self.value_type(list_value[i])
+                except:
+                    raise ListRangeException(
+                        self.name, list_value, str(self.value_type)
+                    )
+            return list_value
+        elif self.type == "tuple":
+            try:
+                tuple_value = eval(string_value)
+            except:
+                raise WrongAttributeTypeError(self.name, "tuple")
+            if not isinstance(tuple_value, tuple):
+                raise WrongAttributeTypeError(self.name, "tuple")
+            list_value = list(tuple_value)
+            for i in range(len(list_value)):
+                try:
+                    list_value[i] = self.value_type(list_value[i])
+                except:
+                    raise ListRangeException(
+                        self.name, list_value, str(self.value_type)
+                    )
+            return tuple(list_value)
+
+
+# Model configuration definitions - using concise dictionary format
+MODEL_CONFIGS = {
+    "NAIVE_FORECASTER": {
+        "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+        "strategy": AttributeConfig(
+            "strategy", "last", "str", choices=["last", "mean"]
+        ),
+        "sp": AttributeConfig("sp", 1, "int", 1, 5000),
+    },
+    "EXPONENTIAL_SMOOTHING": {
+        "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+        "damped_trend": AttributeConfig("damped_trend", False, "bool"),
+        "initialization_method": AttributeConfig(
+            "initialization_method",
+            "estimated",
+            "str",
+            choices=["estimated", "heuristic", "legacy-heuristic", "known"],
+        ),
+        "optimized": AttributeConfig("optimized", True, "bool"),
+        "remove_bias": AttributeConfig("remove_bias", False, "bool"),
+        "use_brute": AttributeConfig("use_brute", False, "bool"),
+    },
+    "ARIMA": {
+        "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+        "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int),
+        "seasonal_order": AttributeConfig(
+            "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int
+        ),
+        "method": AttributeConfig(
+            "method",
+            "lbfgs",
+            "str",
+            choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"],
+        ),
+        "maxiter": AttributeConfig("maxiter", 1, "int", 1, 5000),
+        "suppress_warnings": AttributeConfig("suppress_warnings", True, "bool"),
+        "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000),
+        "scoring": AttributeConfig(
+            "scoring",
+            "mse",
+            "str",
+            choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"],
+        ),
+        "with_intercept": AttributeConfig("with_intercept", True, "bool"),
+        "time_varying_regression": AttributeConfig(
+            "time_varying_regression", False, "bool"
+        ),
+        "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"),
+        "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"),
+        "simple_differencing": AttributeConfig("simple_differencing", False, "bool"),
+        "measurement_error": AttributeConfig("measurement_error", False, "bool"),
+        "mle_regression": AttributeConfig("mle_regression", True, "bool"),
+        "hamilton_representation": AttributeConfig(
+            "hamilton_representation", False, "bool"
+        ),
+        "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"),
+    },
+    "STL_FORECASTER": {
+        "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+        "sp": AttributeConfig("sp", 2, "int", 1, 5000),
+        "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000),
+        "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000),
+        "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000),
+        "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000),
+        "seasonal_jump": AttributeConfig("seasonal_jump", 1,
"int", 0, 5000), + "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 5000), + "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + }, + "GAUSSIAN_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 1e-2, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1.0, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), + "init_params": AttributeConfig( + "init_params", "stmc", "str", choices=["stmc", "stm"] + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "GMM_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["sperical", "diag", "full", "tied"], + ), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "init_params": AttributeConfig( + "init_params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "params": AttributeConfig( + "params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "STRAY": { + "alpha": AttributeConfig("alpha", 0.01, "float", -1e10, 1e10), + "k": AttributeConfig("k", 10, "int", 1, 5000), + "knn_algorithm": AttributeConfig( + "knn_algorithm", + "brute", + "str", + choices=["brute", "kd_tree", "ball_tree", "auto"], + ), + "p": AttributeConfig("p", 0.5, "float", -1e10, 1e10), + 
"size_threshold": AttributeConfig("size_threshold", 50, "int", 1, 5000), + "outlier_tail": AttributeConfig( + "outlier_tail", "max", "str", choices=["min", "max"] + ), + }, +} + + +def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: + """Get attribute configuration for Sktime model""" + model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id + if model_id not in MODEL_CONFIGS: + raise BuiltInModelNotSupportError(model_id) + return MODEL_CONFIGS[model_id] + + +def update_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, AttributeConfig] +) -> Dict[str, Any]: + """Update Sktime model attributes using input attributes""" + attributes = {} + for name, config in attribute_map.items(): + if name in input_attributes: + value = config.parse(input_attributes[name]) + config.validate_value(value) + attributes[name] = value + else: + attributes[name] = config.default + return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json new file mode 100644 index 0000000000000..d6002fb26e87a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "name": "ExponentialSmoothing", + "predict_length": 1, + "damped_trend": false, + "initialization_method": "estimated", + "optimized": true, + "remove_bias": false, + "use_brute": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json new file mode 100644 index 0000000000000..3392e1c0b57c8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json @@ -0,0 +1,20 @@ +{ + "model_type": "sktime", + "name": "GaussianHMM", + "n_components": 1, + "covariance_type": "diag", + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "covars_prior": 0.01, + "covars_weight": 1.0, + "algorithm": "viterbi", + "n_iter": 10, + "tol": 0.01, + "params": "stmc", + "init_params": "stmc", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json new file mode 100644 index 0000000000000..235f8ae642da4 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json @@ -0,0 +1,20 @@ +{ + "model_type": "sktime", + "name": "GMMHMM", + "n_components": 1, + "n_mix": 1, + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "weights_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "algorithm": "viterbi", + "covariance_type": "diag", + "n_iter": 10, + "tol": 0.01, + "init_params": "stmcw", + "params": "stmcw", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py new file mode 100644 index 0000000000000..f272e3dda3579 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +Sktime model implementation module - simplified version +""" + +from abc import abstractmethod +from typing import Any, Dict + +import numpy as np +from sklearn.preprocessing import MinMaxScaler +from sktime.detection.hmm_learn import GMMHMM, GaussianHMM +from sktime.detection.stray import STRAY +from sktime.forecasting.arima import ARIMA +from sktime.forecasting.exp_smoothing import ExponentialSmoothing +from sktime.forecasting.naive import NaiveForecaster +from sktime.forecasting.trend import STLForecaster + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + InferenceModelInternalError, +) +from iotdb.ainode.core.log import Logger + +from .configuration_sktime import get_attributes, update_attribute + +logger = Logger() + + +class SktimeModel: + """Base class for Sktime models""" + + def __init__(self, attributes: Dict[str, Any]): + self._attributes = attributes + self._model = None + + @abstractmethod + def generate(self, data): + """Execute generation/inference""" + raise NotImplementedError + + +class ForecastingModel(SktimeModel): + """Base class for forecasting models""" + + def generate(self, data): + """Execute forecasting""" + try: + predict_length = self._attributes["predict_length"] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + return np.array(output, dtype=np.float64) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class DetectionModel(SktimeModel): + """Base class for detection models""" + + def generate(self, data): + """Execute detection""" + try: + self._model.fit(data) + output = self._model.predict(data) + return np.array(output, dtype=np.int32) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class ArimaModel(ForecastingModel): + """ARIMA model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ARIMA( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class ExponentialSmoothingModel(ForecastingModel): + """Exponential smoothing model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ExponentialSmoothing( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class NaiveForecasterModel(ForecastingModel): + """Naive forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = NaiveForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class STLForecasterModel(ForecastingModel): + """STL forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STLForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class GMMHMMModel(DetectionModel): + """GMM HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = 
GMMHMM(**attributes) + + +class GaussianHmmModel(DetectionModel): + """Gaussian HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GaussianHMM(**attributes) + + +class STRAYModel(DetectionModel): + """STRAY anomaly detection model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STRAY(**attributes) + + def generate(self, data): + """STRAY requires special handling: normalize first""" + try: + data = MinMaxScaler().fit_transform(data) + output = self._model.fit_transform(data) + return np.array(output, dtype=np.int32) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +# Model factory mapping +_MODEL_FACTORY = { + "ARIMA": ArimaModel, + "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel, + "HOLTWINTERS": ExponentialSmoothingModel, # Use the same model class + "NAIVE_FORECASTER": NaiveForecasterModel, + "STL_FORECASTER": STLForecasterModel, + "GMM_HMM": GMMHMMModel, + "GAUSSIAN_HMM": GaussianHmmModel, + "STRAY": STRAYModel, +} + + +def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: + """Create a Sktime model instance""" + attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) + model_class = _MODEL_FACTORY.get(model_id.upper()) + if model_class is None: + raise BuiltInModelNotSupportError(model_id) + return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json new file mode 100644 index 0000000000000..20d8c1ed32b5c --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json @@ -0,0 +1,8 @@ +{ + "model_type": "sktime", + "name": "NaiveForecaster", + "predict_length": 1, + "strategy": "last", + "sp": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json new file mode 100644 index 0000000000000..1005f9d944e9d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json @@ -0,0 +1,14 @@ +{ + "model_type": "sktime", + "name": "STLForecaster", + "predict_length": 1, + "sp": 2, + "seasonal": 7, + "seasonal_deg": 1, + "trend_deg": 1, + "low_pass_deg": 1, + "seasonal_jump": 1, + "trend_jump": 1, + "low_pass_jump": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json new file mode 100644 index 0000000000000..64c64aa9e0514 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "name": "STRAY", + "alpha": 0.01, + "k": 10, + "knn_algorithm": "brute", + "p": 0.5, + "size_threshold": 50, + "outlier_tail": "max" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py deleted file mode 100644 index b2e759e00ce07..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
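For reference, a hedged usage sketch of the sktime factory above; it is not part of the patch, and the model ids and attribute values are illustrative. Keyword attributes are passed as strings, parsed and range-checked by update_attribute() against MODEL_CONFIGS in configuration_sktime.py, and any unspecified attribute falls back to its declared default.

    # Illustrative sketch only; assumes the APIs introduced above.
    import numpy as np
    import pandas as pd

    from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model

    # Forecasting: "predict_length" and "strategy" are validated against MODEL_CONFIGS.
    forecaster = create_sktime_model(
        "naive_forecaster", predict_length="12", strategy="mean"
    )
    history = pd.Series(np.sin(np.arange(96) / 12.0))
    forecast = forecaster.generate(history)  # numpy array with 12 forecast values

    # Detection: STRAY expects a 2-D array and returns an integer outlier mask.
    detector = create_sktime_model("stray", k="5")
    labels = detector.generate(np.random.rand(200, 1))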
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import os -from enum import Enum -from typing import List - -from huggingface_hub import snapshot_download -from requests import Session -from requests.adapters import HTTPAdapter - -from iotdb.ainode.core.constant import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_RECONNECT_TIMEOUT, - DEFAULT_RECONNECT_TIMES, -) -from iotdb.ainode.core.exception import UnsupportedError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.model_info import get_model_file_type - -HTTP_PREFIX = "http://" -HTTPS_PREFIX = "https://" - -logger = Logger() - - -class UriType(Enum): - REPO = "repo" - FILE = "file" - HTTP = "http" - HTTPS = "https" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def parse_uri_type(uri: str): - """ - Parse the URI type from the given string. - """ - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): - return UriType.FILE - elif uri.startswith("http://"): - return UriType.HTTP - elif uri.startswith("https://"): - return UriType.HTTPS - else: - raise ValueError(f"Invalid URI type for {uri}") - - -def get_model_register_strategy(uri: str): - """ - Determine the loading strategy for a model based on its URI/path. - - Args: - uri (str): The URI of the model to be registered. - - Returns: - uri_type (UriType): The type of the URI, which can be one of: REPO, FILE, HTTP, or HTTPS. - parsed_uri (str): Parsed uri to get related file - model_file_type (ModelFileType): The type of the model file, which can be one of: SAFETENSORS, PYTORCH, or UNKNOWN. - """ - - uri_type = UriType.parse_uri_type(uri) - if uri_type in (UriType.HTTP, UriType.HTTPS): - # TODO: support HTTP(S) URI - raise UnsupportedError("CREATE MODEL FROM HTTP(S) URI") - else: - parsed_uri = uri[7:] - if uri_type == UriType.FILE: - # handle ~ in URI - parsed_uri = os.path.expanduser(parsed_uri) - model_file_type = get_model_file_type(uri) - elif uri_type == UriType.REPO: - # Currently, UriType.REPO only corresponds to huggingface repository with SAFETENSORS format - model_file_type = ModelFileType.SAFETENSORS - else: - raise ValueError(f"Invalid URI type for {uri}") - return uri_type, parsed_uri, model_file_type - - -def download_snapshot_from_hf(repo_id: str, local_dir: str): - """ - Download everything from a HuggingFace repository. - - Args: - repo_id (str): The HuggingFace repository ID. - local_dir (str): The local directory to save the downloaded files. 
- """ - try: - snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - ) - except Exception as e: - logger.error(f"Failed to download HuggingFace model {repo_id}: {e}") - raise e - - -def download_file(url: str, storage_path: str) -> None: - """ - Args: - url: url of file to download - storage_path: path to save the file - Returns: - None - """ - logger.info(f"Start Downloading file from {url} to {storage_path}") - session = Session() - adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES) - session.mount(HTTP_PREFIX, adapter) - session.mount(HTTPS_PREFIX, adapter) - response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) - response.raise_for_status() - with open(storage_path, "wb") as file: - for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): - if chunk: - file.write(chunk) - logger.info(f"Download file from {url} to {storage_path} success") diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py new file mode 100644 index 0000000000000..9da8486d3905e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import importlib +import json +import sys +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Tuple + +from iotdb.ainode.core.model.model_enums import ( + MODEL_CONFIG_FILE, + MODEL_WEIGHTS_FILE, + UriType, +) + + +def parse_uri_type(uri: str) -> UriType: + if uri.startswith("repo://"): + return UriType.REPO + elif uri.startswith("file://"): + return UriType.FILE + else: + raise ValueError( + f"Unsupported URI type: {uri}. 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
index f01e1594f0698..de5de968a7e52 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
@@ -20,7 +20,7 @@
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.manager.cluster_manager import ClusterManager
 from iotdb.ainode.core.manager.inference_manager import InferenceManager
-from iotdb.ainode.core.manager.model_manager import ModelManager
+from iotdb.ainode.core.manager.model_manager import get_model_manager
 from iotdb.ainode.core.rpc.status import get_status
 from iotdb.ainode.core.util.gpu_mapping import get_available_devices
 from iotdb.thrift.ainode import IAINodeRPCService
@@ -47,24 +47,10 @@
 logger = Logger()
 
 
-def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus:
-    """
-    Ensure that the device IDs in the provided list are available.
-    """
-    available_devices = get_available_devices()
-    for device_id in device_id_list:
-        if device_id not in available_devices:
-            return TSStatus(
-                code=TSStatusCode.INVALID_URI_ERROR.value,
-                message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
-            )
-    return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
-
-
 class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
     def __init__(self, ainode):
         self._ainode = ainode
-        self._model_manager = ModelManager()
+        self._model_manager = get_model_manager()
         self._inference_manager = InferenceManager()
 
     def stop(self) -> None:
@@ -78,43 +64,55 @@ def stopAINode(self) -> TSStatus:
     def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
         return self._model_manager.register_model(req)
 
-    def loadModel(self, req: TLoadModelReq) -> TSStatus:
-        status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId)
-        if status.code != TSStatusCode.SUCCESS_STATUS.value:
-            return status
-        status = _ensure_device_id_is_available(req.deviceIdList)
-        if status.code != TSStatusCode.SUCCESS_STATUS.value:
-            return status
-        return self._inference_manager.load_model(req)
-
-    def unloadModel(self, req: TUnloadModelReq) -> TSStatus:
-        status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId)
-        if status.code != TSStatusCode.SUCCESS_STATUS.value:
-            return status
-        status = _ensure_device_id_is_available(req.deviceIdList)
-        if status.code != TSStatusCode.SUCCESS_STATUS.value:
-            return status
-        return self._inference_manager.unload_model(req)
-
     def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
         return self._model_manager.delete_model(req)
 
-    def inference(self, req: TInferenceReq) -> TInferenceResp:
-        return self._inference_manager.inference(req)
+    def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
+        return self._model_manager.show_models(req)
 
-    def forecast(self, req: TForecastReq) -> TSStatus:
-        return self._inference_manager.forecast(req)
+    def loadModel(self, req: TLoadModelReq) -> TSStatus:
+        if not self._model_manager.is_model_registered(req.existingModelId):
+            return TSStatus(
+                code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value,
+                message=f"Model [{req.existingModelId}] is not supported. You can use 'SHOW MODELS' to retrieve the available models.",
+            )
 
-    def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
-        return ClusterManager.get_heart_beat(req)
+        available_devices = get_available_devices()
+        for device_id in req.deviceIdList:
+            if device_id not in available_devices:
+                return TSStatus(
+                    code=TSStatusCode.INVALID_URI_ERROR.value,
+                    message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
+                )
 
-    def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
-        return self._model_manager.show_models(req)
+        return self._inference_manager.load_model(req)
+
+    def unloadModel(self, req: TUnloadModelReq) -> TSStatus:
+        if not self._model_manager.is_model_registered(req.modelId):
+            return TSStatus(
+                code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value,
+                message=f"Model [{req.modelId}] is not supported. You can use 'SHOW MODELS' to retrieve the available models.",
+            )
+
+        available_devices = get_available_devices()
+        for device_id in req.deviceIdList:
+            if device_id not in available_devices:
+                return TSStatus(
+                    code=TSStatusCode.INVALID_URI_ERROR.value,
+                    message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
+                )
+
+        return self._inference_manager.unload_model(req)
 
     def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp:
-        status = _ensure_device_id_is_available(req.deviceIdList)
-        if status.code != TSStatusCode.SUCCESS_STATUS.value:
-            return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={})
+        available_devices = get_available_devices()
+        for device_id in req.deviceIdList:
+            if device_id not in available_devices:
+                status = TSStatus(
+                    code=TSStatusCode.INVALID_URI_ERROR.value,
+                    message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
+                )
+                return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={})
         return self._inference_manager.show_loaded_models(req)
 
     def showAIDevices(self) -> TShowAIDevicesResp:
@@ -123,13 +121,14 @@ def showAIDevices(self) -> TShowAIDevicesResp:
             deviceIdList=get_available_devices(),
         )
 
+    def inference(self, req: TInferenceReq) -> TInferenceResp:
+        return self._inference_manager.inference(req)
+
+    def forecast(self, req: TForecastReq) -> TSStatus:
+        return self._inference_manager.forecast(req)
+
+    def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+        return ClusterManager.get_heart_beat(req)
+
     def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
         pass
-
-    def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus:
-        if not self._model_manager.is_built_in_or_fine_tuned(model_id):
-            return TSStatus(
-                code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value,
-                message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.",
-            )
-        return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
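With these changes the handler no longer constructs ModelManager directly; every RPC goes through get_model_manager(), so checks such as is_model_registered() see one shared registry. The accessor itself lives in model_manager.py and is not part of this diff; a thread-safe module-level singleton along the following lines would satisfy the call sites above (a sketch under that assumption, not the actual implementation).

import threading

_MODEL_MANAGER = None
_MODEL_MANAGER_LOCK = threading.Lock()


def get_model_manager():
    # Lazily create a single shared ModelManager (defined elsewhere in
    # model_manager.py); double-checked locking keeps concurrent RPC threads
    # from racing on first use.
    global _MODEL_MANAGER
    if _MODEL_MANAGER is None:
        with _MODEL_MANAGER_LOCK:
            if _MODEL_MANAGER is None:
                _MODEL_MANAGER = ModelManager()
    return _MODEL_MANAGER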
diff --git a/iotdb-core/ainode/poetry.lock b/iotdb-core/ainode/poetry.lock
index a2d8ce581f16e..4ba865f2be720 100644
--- a/iotdb-core/ainode/poetry.lock
+++ b/iotdb-core/ainode/poetry.lock
@@ -327,6 +327,8 @@ files = [
     {file = "greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d"},
     {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5"},
     {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f"},
+    {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7"},
+    {file = "greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8"},
     {file = "greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c"},
     {file = "greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2"},
     {file = "greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246"},
@@ -336,6 +338,8 @@ files = [
     {file = "greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8"},
     {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52"},
     {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa"},
+    {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c"},
+    {file = "greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5"},
     {file = "greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9"},
     {file = "greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd"},
     {file = "greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb"},
@@ -345,6 +349,8 @@ files = [
     {file = "greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0"},
     {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0"},
     {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f"},
+    {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0"},
+    {file = "greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d"},
     {file = "greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02"},
     {file = "greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31"},
     {file = "greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945"},
@@ -354,6 +360,8 @@ files = [
     {file = "greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671"},
     {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b"},
     {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae"},
+    {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b"},
+    {file = "greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929"},
     {file = "greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b"},
     {file = "greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0"},
     {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f"},
@@ -361,6 +369,8 @@
     {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1"},
     {file = "greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735"},
     {file = "greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337"},
+    {file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269"},
+    {file = "greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681"},
     {file = "greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01"},
     {file = "greenlet-3.2.4-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:b6a7c19cf0d2742d0809a4c05975db036fdff50cd294a93632d6a310bf9ac02c"},
     {file = "greenlet-3.2.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:27890167f55d2387576d1f41d9487ef171849ea0359ce1510ca6e06c8bece11d"},
@@ -370,6 +380,8 @@ files = [
     {file = "greenlet-3.2.4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9913f1a30e4526f432991f89ae263459b1c64d1608c0d22a5c79c287b3c70df"},
     {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b90654e092f928f110e0007f572007c9727b5265f7632c2fa7415b4689351594"},
     {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:81701fd84f26330f0d5f4944d4e92e61afe6319dcd9775e39396e39d7c3e5f98"},
+    {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:28a3c6b7cd72a96f61b0e4b2a36f681025b60ae4779cc73c1535eb5f29560b10"},
+    {file = "greenlet-3.2.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:52206cd642670b0b320a1fd1cbfd95bca0e043179c1d8a045f2c6109dfe973be"},
     {file = "greenlet-3.2.4-cp39-cp39-win32.whl", hash = "sha256:65458b409c1ed459ea899e939f0e1cdb14f58dbc803f2f93c5eab5694d32671b"},
     {file = "greenlet-3.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:d2e685ade4dafd447ede19c31277a224a239a0a1a4eca4e6390efedf20260cfb"},
     {file = "greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d"},
@@ -1950,6 +1962,12 @@ files = [
     {file = "statsmodels-0.14.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a085d47c8ef5387279a991633883d0e700de2b0acc812d7032d165888627bef"},
     {file = "statsmodels-0.14.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9f866b2ebb2904b47c342d00def83c526ef2eb1df6a9a3c94ba5fe63d0005aec"},
     {file = "statsmodels-0.14.5-cp313-cp313-win_amd64.whl", hash = "sha256:2a06bca03b7a492f88c8106103ab75f1a5ced25de90103a89f3a287518017939"},
+    {file = "statsmodels-0.14.5-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:07c4dad25bbb15864a31b4917a820f6d104bdc24e5ddadcda59027390c3bed9e"},
+    {file = "statsmodels-0.14.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:babb067c852e966c2c933b79dbb5d0240919d861941a2ef6c0e13321c255528d"},
+    {file = "statsmodels-0.14.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:110194b137286173cc676d7bad0119a197778de6478fc6cbdc3b33571165ac1e"},
+    {file = "statsmodels-0.14.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c8a9c384a60c80731b278e7fd18764364c8817f4995b13a175d636f967823d1"},
+    {file = "statsmodels-0.14.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:557df3a870a57248df744fdfcc444ecbc5bdbf1c042b8a8b5d8e3e797830dc2a"},
+    {file = "statsmodels-0.14.5-cp314-cp314-win_amd64.whl", hash = "sha256:95af7a9c4689d514f4341478b891f867766f3da297f514b8c4adf08f4fa61d03"},
     {file = "statsmodels-0.14.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b23b8f646dd78ef5e8d775d879208f8dc0a73418b41c16acac37361ff9ab7738"},
     {file = "statsmodels-0.14.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e5e26b21d2920905764fb0860957d08b5ba2fae4466ef41b1f7c53ecf9fc7fa"},
     {file = "statsmodels-0.14.5-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a060c7e0841c549c8ce2825fd6687e6757e305d9c11c9a73f6c5a0ce849bb69"},