Created
May 5, 2023 22:19
-
-
Save quiiver/4938b72a7fac2f2fbb75ec541d64bfac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
from typing import Any, Dict, List, Optional, Union | |
from airflow import XComArg | |
from airflow.hooks.base import BaseHook | |
from dags.operators.gcp_container_operator import GKEPodOperator | |
class MerinoJob(GKEPodOperator):
    """GKEPodOperator preset that runs ``python -m merino.jobs.cli`` in the merino-py image.

    The image, GCP project, connection id, cluster, location and notification
    e-mail are fixed here; callers supply the pod name, the CLI arguments and,
    optionally, environment variables.
    """

    def __init__(self, name: str, arguments: List[str], env_vars: Optional[Dict[str, Any]] = None, *args, **kwargs):
        """Configure the pod.

        Args:
            name: Pod name; also used as ``task_id`` when none is given.
            arguments: Arguments appended to the ``merino.jobs.cli`` command.
            env_vars: Environment variables for the pod, forwarded as-is
                (no defaults are merged in at this level).
            *args: Extra positional arguments for ``GKEPodOperator``.
            **kwargs: Extra keyword arguments for ``GKEPodOperator``.
        """
        # An explicit task_id from the caller wins; otherwise reuse the pod name.
        kwargs.setdefault("task_id", name)

        super().__init__(
            *args,
            name=name,
            image="mozilla/merino-py:latest",
            project_id="moz-fx-data-airflow-gke-prod",
            gcp_conn_id="google_cloud_airflow_gke",
            cluster_name="workloads-prod-v1",
            location="us-west1",
            cmds=["python", "-m", "merino.jobs.cli"],
            arguments=arguments,
            env_vars=env_vars,
            email=[
                "[email protected]",
            ],
            **kwargs,
        )
@dataclasses.dataclass
class MerinoJobTask:
    """Description of one Merino job run: its arguments, connections and env vars."""

    # Task name; ``build`` replaces it with one derived from ``arguments``.
    name: str
    # Job arguments; the first two are joined with "-" to form the built name.
    arguments: List[str]
    # Connection ids whose details are exported as environment variables.
    connections: List[str]
    # Caller-supplied environment variables.
    env_vars: Dict[str, str]

    def build(self) -> "MerinoJobTask":
        """Return a new task with a derived name and fully resolved env vars.

        Environment variables are layered in increasing priority: the
        ``MERINO_ENV=production`` default, then ``self.env_vars``, then the
        values derived from ``self.connections``. The returned task has an
        empty ``connections`` list because they are already folded into
        ``env_vars``.
        """
        resolved_env: Dict[str, str] = {"MERINO_ENV": "production"}
        # Tolerate a missing env_vars mapping instead of failing on ``**None``.
        resolved_env.update(self.env_vars or {})
        resolved_env.update(self._get_connections())
        return MerinoJobTask(
            name="-".join(self.arguments[:2]),
            arguments=self.arguments,
            env_vars=resolved_env,
            connections=[],
        )

    def _get_connections(self) -> Dict[str, str]:
        """Merge the env vars contributed by every id in ``self.connections``."""
        env_vars: Dict[str, str] = {}
        for conn_id in self.connections or []:
            env_vars.update(self._get_conn_details(conn_id))
        return env_vars

    def _get_conn_details(self, conn_id: str) -> Dict:
        """Return the env vars for a single connection id.

        Only ``elasticsearch_prod`` and ``elasticsearch_stage`` are recognised;
        they are looked up as the connection ``merino_<conn_id>``. Any other
        id contributes nothing.
        """
        # Compare by value: ``is`` tests object identity, which does not hold
        # for equal strings created at runtime (e.g. parsed from DAG params).
        if conn_id in ("elasticsearch_prod", "elasticsearch_stage"):
            conn = BaseHook.get_connection(f"merino_{conn_id}")
            return {
                "MERINO_JOBS__WIKIPEDIA_INDEXER__ES_URL": str(conn.host),
                "MERINO_JOBS__WIKIPEDIA_INDEXER__ES_API_KEY": str(conn.password),
            }
        return {}
@dataclasses.dataclass
class MerinoJobParams:
    """Merino job definition created from DAG parameters.

    Expected parameter shape::

        {
            "task": {
                "arguments": ["wikipedia-indexer", "..."],
                "connections": ["elasticsearch_prod"],
                "env_vars": {"FOO": "bar"}
            },
            "dry_run": false
        }

    ``task`` may be passed either as a ``MerinoJobTask`` or as the plain
    mapping shown above; a mapping is converted on construction.
    """

    task: Union[MerinoJobTask, Dict[str, Any]]
    dry_run: bool

    def __post_init__(self) -> None:
        """Convert a plain ``task`` mapping into a ``MerinoJobTask``.

        Dataclasses do not coerce nested fields, so building this object with
        ``MerinoJobParams(**params)`` leaves ``task`` as a dict, and
        ``build_task`` would then fail on ``self.task.build()``.
        """
        if isinstance(self.task, dict):
            raw = self.task
            self.task = MerinoJobTask(
                # The name is recomputed by ``MerinoJobTask.build``.
                name=raw.get("name", ""),
                arguments=list(raw.get("arguments") or []),
                connections=list(raw.get("connections") or []),
                env_vars=dict(raw.get("env_vars") or {}),
            )

    def build_task(self) -> dict:
        """Build the task and return its fields, plus ``dry_run``, as a dict."""
        task = self.task.build()
        return {
            "dry_run": self.dry_run,
            "arguments": task.arguments,
            "env_vars": task.env_vars,
            "name": task.name,
        }
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
from typing import Dict, Any | |
from airflow.decorators import dag, task | |
from airflow.models.param import Param | |
from airflow.operators.python import get_current_context | |
from utils.merino import MerinoJob, MerinoJobParams | |
from utils.tags import Tag | |
# Markdown description handed to the DAG through its ``doc_md`` argument.
doc_md = """
# Merino Job Adhoc DAG
#### Use with caution
#### Some tips/notes:
* Always use dry run first.
"""
@task()
def generate_task() -> Dict[str, Any]:
    """Turn the current run's ``params`` into a built Merino task definition.

    Returns:
        The dict produced by ``MerinoJobParams.build_task``.
    """
    # A context without "params" is treated as having no parameters at all.
    run_params = get_current_context().get("params", {})
    return MerinoJobParams(**run_params).build_task()
@dag(
    dag_id="merino_jobs_adhoc",
    doc_md=doc_md,
    tags=[Tag.ImpactTier.tier_3, Tag.Triage.no_triage],
    # No schedule and no catch-up: this DAG only runs when triggered.
    schedule_interval=None,
    catchup=False,
    start_date=datetime.datetime(2023, 5, 1),
    dagrun_timeout=datetime.timedelta(days=1),
    render_template_as_native_obj=True,
    params={
        "task": Param(
            "task",
            schema={
                "type": "object",
                "default": {"arguments": [], "connections": []},
                "properties": {
                    "arguments": {"type": "array", "items": {"type": "string"}},
                    "connections": {"type": "array", "items": {"type": "string"}},
                    "env_vars": {"type": "object"},
                },
                "required": ["arguments"],
            },
        ),
        "dry_run": Param(True, type="boolean"),
    },
)
def adhoc_dag():
    """Ad-hoc DAG: build a job definition from the run params, then run it as a MerinoJob."""
    job_spec = generate_task()

    # NOTE(review): generate_task also returns "dry_run", but nothing in this
    # function reads it — confirm how a dry run is meant to reach the job.
    MerinoJob(
        task_id="adhoc_merino_job",
        name=job_spec["name"],
        arguments=job_spec["arguments"],
        env_vars=job_spec["env_vars"],
    )


dag = adhoc_dag()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment