# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import logging
import threading
import uuid
# Module-level logger, named after this module (standard `logging` convention).
logger = logging.getLogger(__name__)
class TrackingLogger:
    """No-op experiment tracker used when MLflow is unavailable.

    Exposes the same public interface as the MLflow-backed tracker so callers
    can use the module-level ``tracking_logger`` unconditionally. Metrics,
    params and artifacts are discarded; nested runs are only tracked locally
    as a stack of run identifiers.
    """

    class ActiveRun:
        """Context manager returned by ``start_nested_run``.

        Args:
            run_uuid: identifier of the run.
            experiment_id: identifier of the experiment the run belongs to.
            on_exit: optional callable invoked with ``run_uuid`` when the
                ``with`` block is left (normally or via an exception).
        """

        def __init__(self, run_uuid, experiment_id, on_exit=None):
            self.run_uuid = run_uuid
            self.experiment_id = experiment_id
            # Kept for backward compatibility: this attribute has always held
            # the experiment *id*, despite its name.
            self.experiment_name = experiment_id
            self._on_exit = on_exit

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self._on_exit is not None:
                self._on_exit(self.run_uuid)
            # Never suppress an exception raised inside the ``with`` body.
            return False

    def __init__(self):
        self._experiment_id = None
        self._experiment_name = None
        # Stack of currently active (nested) run identifiers; last is current.
        self._active_run_uuids = []
        # Read by the ``run_name`` property; must exist before first access.
        self._run_name = None

    def start_nested_run(self, run_name=None):
        """Start a local nested run and return it as a context manager.

        Args:
            run_name: identifier to use for the run; a random hex UUID is
                generated when it is falsy.

        Returns:
            TrackingLogger.ActiveRun: leaving its ``with`` block removes the
            run from the active-run stack again.
        """
        run_uuid = run_name or uuid.uuid4().hex
        self._active_run_uuids.append(run_uuid)
        return self.ActiveRun(
            run_uuid=run_uuid,
            experiment_id=self._experiment_id,
            on_exit=self._finish_run,
        )

    def _finish_run(self, run_uuid):
        """Pop ``run_uuid`` if it is the innermost active run."""
        # Context managers exit in LIFO order, so only the top needs checking.
        if self._active_run_uuids and self._active_run_uuids[-1] == run_uuid:
            self._active_run_uuids.pop()

    @property
    def run_uuid(self):
        """Identifier of the innermost active run, or None if there is none."""
        return self._active_run_uuids[-1] if self._active_run_uuids else None

    @property
    def experiment_id(self):
        return self._experiment_id

    @property
    def experiment_name(self):
        return self._experiment_name

    @property
    def run_name(self):
        return self._run_name

    def should_log_artifacts(self):
        """Artifacts are never uploaded without a tracking backend."""
        return False

    def get_tracking_uri(self):
        return None

    @staticmethod
    def log_metric(key, value):
        pass

    @staticmethod
    def log_param(key, value):
        pass

    @staticmethod
    def log_artifact(local_path, artifact_path=None):
        pass

    @staticmethod
    def log_artifacts(local_dir, artifact_path=None):
        return None

    @staticmethod
    def get_artifact_uri():
        return None

    @staticmethod
    def end_run():
        pass
class MLflowLogger:
    """Experiment tracker that forwards metrics, params and artifacts to MLflow.

    Artifact uploads run on daemon threads so the caller is not blocked; each
    upload method returns a ``threading.Event`` that is set once the upload
    has finished, whether it succeeded or failed.
    """

    def __init__(self):
        # When True, artifacts are uploaded even if the tracking URI is local.
        self.always_log_artifacts = False
        self._experiment_name = None
        self._run_name = None

    def start_nested_run(self, run_name=None):
        """Start an MLflow run nested under the currently active run."""
        return mlflow.start_run(run_name=run_name, nested=True)

    @property
    def run_uuid(self):
        """UUID of the active MLflow run, or None if no run is active."""
        run = mlflow.tracking.fluent.active_run()
        return run.info.run_uuid if run is not None else None

    @property
    def experiment_id(self):
        """Experiment id of the active MLflow run, or None if none is active."""
        run = mlflow.tracking.fluent.active_run()
        return run.info.experiment_id if run is not None else None

    @property
    def experiment_name(self):
        return self._experiment_name

    @property
    def run_name(self):
        return self._run_name

    def should_log_artifacts(self):
        """Upload artifacts when forced to, or when tracking is remote."""
        return self.always_log_artifacts or self._is_remote()

    @staticmethod
    def get_tracking_uri():
        return mlflow.get_tracking_uri()

    @staticmethod
    def log_metric(key, value):
        mlflow.log_metric(key, value)

    @staticmethod
    def log_param(key, value):
        mlflow.log_param(key, value)

    @staticmethod
    def _run_in_background(func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` on a daemon thread.

        Returns:
            threading.Event: set when the call returns or raises, so waiters
            can never block forever on a failed upload.
        """
        done = threading.Event()

        def _target():
            try:
                func(*args, **kwargs)
            except Exception:
                logger.exception('Background MLflow upload failed')
            finally:
                done.set()

        threading.Thread(target=_target, daemon=True).start()
        return done

    @staticmethod
    def log_artifact(local_path, artifact_path=None):
        """Upload one file in the background; returns a completion Event."""
        return MLflowLogger._run_in_background(
            mlflow.log_artifact, local_path, artifact_path=artifact_path
        )

    @staticmethod
    def log_artifacts(local_dir, artifact_path=None):
        """Upload a directory in the background; returns a completion Event."""
        return MLflowLogger._run_in_background(
            mlflow.log_artifacts, local_dir, artifact_path=artifact_path
        )

    @staticmethod
    def get_artifact_uri():
        return mlflow.get_artifact_uri()

    @staticmethod
    def end_run():
        mlflow.end_run()

    def _is_remote(self):
        """True when the tracking URI does not point at local storage."""
        return not mlflow.tracking.utils._is_local_uri(
            mlflow.get_tracking_uri()
        )

    @staticmethod
    def _retrieve_mlflow_experiment_id(name, create=False):
        """Look up the MLflow experiment id for ``name``.

        Args:
            name: experiment name; a falsy name yields None.
            create: create the experiment when it does not exist yet.

        Raises:
            Exception: the experiment does not exist and ``create`` is False.
        """
        if not name:
            return None
        existing_experiment = MlflowClient().get_experiment_by_name(name)
        if existing_experiment:
            return existing_experiment.experiment_id
        if create:
            return mlflow.create_experiment(name)
        raise Exception(
            'Experiment "{}" not found in {}'.format(
                name, mlflow.get_tracking_uri()
            )
        )
# Choose the tracking backend once, at import time: MLflow when the package
# is installed, otherwise the no-op tracker with the same interface.
try:
    import mlflow
    from mlflow.tracking import MlflowClient
except ImportError:
    tracking_logger = TrackingLogger()
else:
    tracking_logger = MLflowLogger()