diff --git a/impc_etl/jobs/load/impc_kg/publications_mapper.py b/impc_etl/jobs/load/impc_kg/publications_mapper.py
index edcc553f..c4c8630d 100644
--- a/impc_etl/jobs/load/impc_kg/publications_mapper.py
+++ b/impc_etl/jobs/load/impc_kg/publications_mapper.py
@@ -1,81 +1,76 @@
-import luigi
-from impc_etl.jobs.load.impc_bulk_api.impc_api_mapper import to_camel_case
-from luigi.contrib.spark import PySparkTask
-from pyspark import SparkContext
-from pyspark.sql import SparkSession
-from pyspark.sql.functions import col
+"""
+Module to generate the publications data as JSON for the KG.
+"""
+import logging
+import textwrap
 
-from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id, map_unique_ids
-from impc_etl.workflow.config import ImpcConfig
+from airflow.sdk import Variable, asset
+from impc_etl.utils.airflow import create_input_asset, create_output_asset
+from impc_etl.utils.spark import with_spark_publications_mongo_session
 
-class ImpcKgPublicationsMapper(PySparkTask):
-    """
-    PySpark Task class to parse GenTar Product report data.
-    """
+task_logger = logging.getLogger("airflow.task")
+dr_tag = Variable.get("data_release_tag")
 
-    #: Name of the Spark task
-    name: str = "ImpcKgPublicationsMapper"
+# The asset has no dependencies on files as the data used to create it
+# is extracted from the publications MongoDB, but to ensure it runs while
+# the KG is being built a dependency has been added on
+# the output of the KG procedure_mapper.
+procedure_json_path_asset = create_input_asset("output/impc_kg/procedure_json")
 
-    #: Path of the output directory where the new parquet file will be generated.
-    output_path: luigi.Parameter = luigi.Parameter()
+publications_output_asset = create_output_asset("/impc_kg/publications_json")
 
-    def requires(self):
-        return []
-
-    def output(self):
-        """
-        Returns the full parquet path as an output for the Luigi Task
-        (e.g. impc/dr15.2/parquet/product_report_parquet)
-        """
-        return ImpcConfig().get_target(f"{self.output_path}/impc_kg/publications_json")
-
-    def app_options(self):
-        """
-        Generates the options pass to the PySpark job
-        """
-        return [
-            self.output().path,
-        ]
-
-    def main(self, sc: SparkContext, *args):
+@asset.multi(
+    schedule=[procedure_json_path_asset],
+    outlets=[publications_output_asset],
+    dag_id=f"{dr_tag}_impc_kg_publications_mapper",
+    description=textwrap.dedent(
         """
-        Takes in a SparkContext and the list of arguments generated by `app_options` and executes the PySpark job.
+        PySpark task to create the publications Knowledge Graph JSON files
+        based on the data in the production publications MongoDB.
         """
-        spark = SparkSession(sc)
-
-        # Parsing app options
-        output_path = args[0]
-
-        publications_df = spark.read.format("mongodb").load()
-
-        publications_df = publications_df.where(col("status") == "reviewed").select(
-            "title",
-            "authorString",
-            "consortiumPaper",
-            "doi",
-            col("firstPublicationDate").alias("publicationDate"),
-            col("journalInfo.journal.title").alias("journalTitle"),
-            col("alleles.acc").alias("mgiAlleleAccessionIds"),
-            col("pmid").alias("pmId"),
-            "abstractText",
-            "meshHeadingList",
-            "grantsList",
-        )
-
-        publications_df = add_unique_id(
-            publications_df,
-            "publication_id",
-            ["pmId"],
-        )
-
-        publications_df = map_unique_ids(
-            publications_df, "alleles", "mgiAlleleAccessionIds"
-        )
-
-        publications_df = publications_df.drop("mgiAlleleAccessionIds")
-
-        publications_df.coalesce(1).write.json(
-            output_path, mode="overwrite", compression="gzip"
-        )
+    ),
+    tags=["impc_kg"],
+)
+@with_spark_publications_mongo_session
+def impc_kg_publications_mapper():
+
+    from impc_etl.jobs.load.impc_kg.impc_kg_helper import add_unique_id, map_unique_ids
+
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import col
+
+    spark = SparkSession.builder.getOrCreate()
+
+    publications_df = spark.read.format("mongodb").option("collection", "references").load()
+
+    publications_df = publications_df.where(col("status") == "reviewed").select(
+        "title",
+        "authorString",
+        "consortiumPaper",
+        "doi",
+        col("firstPublicationDate").alias("publicationDate"),
+        col("journalInfo.journal.title").alias("journalTitle"),
+        col("alleles.acc").alias("mgiAlleleAccessionIds"),
+        col("pmid").alias("pmId"),
+        "abstractText",
+        "meshHeadingList",
+        "grantsList",
+    )
+
+    publications_df = add_unique_id(
+        publications_df,
+        "publication_id",
+        ["pmId"],
+    )
+
+    publications_df = map_unique_ids(
+        publications_df, "alleles", "mgiAlleleAccessionIds"
+    )
+
+    publications_df = publications_df.drop("mgiAlleleAccessionIds")
+
+    publications_df.coalesce(1).write.json(
+        publications_output_asset.uri, mode="overwrite", compression="gzip"
+    )