db_org: str
dme_port: Optional[str] = None
source_name: Optional[str] = None
- measured_obj_class: Optional[str] = None
\ No newline at end of file
+ measured_obj_class: Optional[str] = None
+
+class ModelRegistrationIntent(BaseModel):
+ model_name: str
+ model_version: str
+ description: str
+ author: str
+ owner: Optional[str] = ""
+ input_data_type: str
+ output_data_type: str
+ model_location: Optional[str] = None
+ artifact_version: Optional[str] = None
\ No newline at end of file
import dspy
from trainingmgr.common.trainingmgr_config import TrainingMgrConfig
from trainingmgr.common.exceptions_utls import TMException
-from trainingmgr.schemas.agent_schema import FeatureGroupIntent
+from trainingmgr.schemas.agent_schema import FeatureGroupIntent, ModelRegistrationIntent
CONFIG = TrainingMgrConfig()
LOGGER = CONFIG.logger
except Exception as err:
raise TMException(f"Error creating feature group: {str(err)}")
+@dspy.Tool
+def register_model(
+ model_name: str,
+ model_version: str,
+ description: str,
+ author: str,
+ owner: str = "",
+ input_data_type: str = "",
+ output_data_type: str = "",
+ model_location: str = "",
+ artifact_version: str = ""
+) -> str:
+ """Register a model in the Model Management Service (MME)."""
+ try:
+ mme_ip = CONFIG.model_management_service_ip
+ mme_port = CONFIG.model_management_service_port
+ if not mme_ip or not mme_port:
+ raise TMException("Model management service IP/Port not configured")
+
+ obj = ModelRegistrationIntent(
+ model_name=model_name,
+ model_version=model_version,
+ description=description,
+ author=author,
+ owner=owner or "",
+ input_data_type=input_data_type or "",
+ output_data_type=output_data_type or "",
+ model_location=model_location or None,
+ artifact_version=(artifact_version or None),
+ )
+
+ payload = {
+ "modelId": {
+ "modelName": obj.model_name,
+ "modelVersion": obj.model_version,
+ },
+ "description": obj.description,
+ "modelInformation": {
+ "metadata": {
+ "author": obj.author,
+ "owner": obj.owner or "",
+ },
+ "inputDataType": obj.input_data_type,
+ "outputDataType": obj.output_data_type,
+ },
+ }
+ if obj.model_location:
+ payload["modelLocation"] = obj.model_location
+ if obj.artifact_version:
+ payload["modelId"]["artifactVersion"] = obj.artifact_version
+
+ url = f"http://{mme_ip}:{mme_port}/ai-ml-model-registration/v1/model-registrations"
+ response = requests.post(url, json=payload, timeout=15)
+ response.raise_for_status()
+ return f"Model '{model_name}' version '{model_version}' registered (status={response.status_code})."
+ except Exception as err:
+ raise TMException(f"Error registering model: {str(err)}")
# Define the agent signature
class AgentSignature(dspy.Signature):
# Agent configuration
self._agent = dspy.ReAct(
AgentSignature,
- tools=[create_feature_group],
+ tools=[create_feature_group, register_model],
max_iters=6
)