def _register_uc_feature_tables_from_diffs(
    self, registry_diff: RegistryDiff
) -> None:
    """Register applied feature views as Unity Catalog feature tables.

    Does nothing unless the offline store config is a
    ``DatabricksUCOfflineStoreConfig`` whose ``uc_registration`` section is
    present and has ``enabled`` set to True.

    Args:
        registry_diff: Diff produced by the apply. The ``new_feast_object``
            of every entry in ``feast_object_diffs`` is inspected and the
            ``FeatureView`` instances are forwarded for registration.
    """
    # The UC modules import pyspark at module level, and this hook is called
    # on every apply. A missing Spark/Databricks extra must therefore be
    # treated as "UC registration not applicable" rather than crash applies
    # for stores that never use Unity Catalog.
    try:
        from feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc import (
            DatabricksUCOfflineStoreConfig,
        )
        from feast.infra.offline_stores.contrib.spark_offline_store.uc_registration import (
            register_uc_feature_tables,
        )
    except ImportError:
        return

    offline_store = self.config.offline_store
    if not isinstance(offline_store, DatabricksUCOfflineStoreConfig):
        return

    uc_config = offline_store.uc_registration
    if uc_config is None or not uc_config.enabled:
        return

    # Deleted objects have no new_feast_object; only FeatureViews are sent on.
    feature_views = [
        diff.new_feast_object
        for diff in registry_diff.feast_object_diffs
        if diff.new_feast_object and isinstance(diff.new_feast_object, FeatureView)
    ]
    if not feature_views:
        return

    register_uc_feature_tables(offline_store, feature_views, self.project)
-407,4 +407,16 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: progress=lambda x: None, ) + # UC-backed materialization hook (Phase L3) + from feast.infra.offline_stores.contrib.spark_offline_store.uc_registration import ( + write_uc_materialized_data, + ) + + write_uc_materialized_data( + config=context.repo_config, + fv=self.feature_view, + df=input_table, + project=context.repo_config.project, + ) + return output diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py index 92964b72bc9..0df3b774973 100644 --- a/sdk/python/feast/infra/compute_engines/spark/nodes.py +++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py @@ -590,6 +590,18 @@ def execute(self, context: ExecutionContext) -> DAGValue: ) spark_df.write.format(file_format).mode("append").save(dest_path) + # UC-backed materialization hook (Phase L3) + from feast.infra.offline_stores.contrib.spark_offline_store.uc_registration import ( + write_uc_materialized_data, + ) + + write_uc_materialized_data( + config=context.repo_config, + fv=self.feature_view, + df=spark_df, + project=context.repo_config.project, + ) + return DAGValue( data=spark_df, format=DAGFormat.SPARK, diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py new file mode 100644 index 00000000000..3fd48b67d96 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/databricks_uc.py @@ -0,0 +1,336 @@ +import logging +from datetime import date, datetime +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import pandas as pd +import pyarrow +import pyspark +from pydantic import StrictBool, StrictStr +from pyspark import SparkConf +from pyspark.sql import SparkSession + +from feast import FeatureView +from feast.data_source import DataSource +from 
feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkOfflineStore, + SparkOfflineStoreConfig, +) +from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.registry.base_registry import BaseRegistry +from feast.repo_config import FeastConfigBaseModel, RepoConfig + +logger = logging.getLogger(__name__) + + +class UCRegistrationConfig(FeastConfigBaseModel): + """Configuration for Unity Catalog feature table registration during ``feast apply``.""" + + enabled: StrictBool = True + """Whether to register feature views as UC feature tables on ``feast apply``.""" + + catalog: Optional[StrictStr] = None + """Default catalog for UC feature tables. Overrides ``DatabricksUCOfflineStoreConfig.default_catalog``.""" + + uc_schema: Optional[StrictStr] = None + """Default schema for UC feature tables. Overrides ``DatabricksUCOfflineStoreConfig.default_schema``.""" + + +class DatabricksUCOfflineStoreConfig(SparkOfflineStoreConfig): + type: StrictStr = "databricks_uc" + """Offline store type selector""" + + workspace_host: Optional[StrictStr] = None + """Databricks workspace host (e.g. 
def _normalize_workspace_host(workspace_host: Optional[str]) -> Optional[str]:
    """Reduce a workspace URL to its bare host name.

    Strips surrounding whitespace, a leading ``https://`` / ``http://`` scheme
    and everything from the first ``/`` onwards, so ``https://host/`` and
    ``host`` both yield ``host``. Returns ``None`` when nothing is left.
    """
    if not workspace_host:
        return None
    host = workspace_host.strip()
    for scheme in ("https://", "http://"):
        if host.startswith(scheme):
            host = host[len(scheme) :]
            break
    # A trailing slash or path would otherwise end up in front of ":443".
    host = host.split("/", 1)[0]
    return host or None


def _build_connect_uri(host: str, cluster_id: str, token: Optional[str]) -> str:
    """Build the Spark Connect URI (``sc://host:443/;k=v;...``) for a cluster."""
    params = []
    if token:
        params.append(f"token={token}")
    params.append(f"x-databricks-cluster-id={cluster_id}")
    return f"sc://{host}:443/;{';'.join(params)}"


def _quote_identifier(name: str) -> str:
    """Backtick-quote a SQL identifier, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def get_databricks_session(
    store_config: DatabricksUCOfflineStoreConfig,
) -> SparkSession:
    """Return a Spark session set up for the Databricks UC offline store.

    Reuses ``SparkSession.getActiveSession()`` when one exists. Otherwise a
    session is built, preferring ``DatabricksSession`` from
    ``databricks-connect`` and falling back to the plain ``SparkSession``
    builder when that package cannot be imported. A remote Spark Connect URI
    is used only when both ``workspace_host`` and ``cluster_id`` are set.

    On every call, for a new or a reused session, the
    ``spark.sql.parser.quotedRegexColumnNames`` option is set to ``"true"``
    and the configured default catalog and schema are selected.

    Args:
        store_config: Offline store configuration holding the connection
            settings, optional ``spark_conf`` and default catalog/schema.

    Returns:
        The active or newly created Spark session.
    """
    spark_session = SparkSession.getActiveSession()
    if not spark_session:
        host = _normalize_workspace_host(store_config.workspace_host)
        cluster_id = store_config.cluster_id

        try:
            from databricks.connect import DatabricksSession

            builder = DatabricksSession.builder
        except ImportError:
            # databricks-connect is optional; plain PySpark is the fallback.
            builder = SparkSession.builder

        if host and cluster_id:
            builder = builder.remote(
                _build_connect_uri(host, cluster_id, store_config.token)
            )

        spark_conf = store_config.spark_conf
        if spark_conf:
            builder = builder.config(
                conf=SparkConf().setAll(list(spark_conf.items()))
            )

        spark_session = builder.getOrCreate()

    assert spark_session is not None

    # Apply configuration defaults
    spark_session.conf.set("spark.sql.parser.quotedRegexColumnNames", "true")

    if store_config.default_catalog:
        spark_session.sql(
            f"USE CATALOG {_quote_identifier(store_config.default_catalog)}"
        )
    if store_config.default_schema:
        spark_session.sql(
            f"USE SCHEMA {_quote_identifier(store_config.default_schema)}"
        )

    return spark_session
@staticmethod
def offline_write_batch(
    config: RepoConfig,
    feature_view: FeatureView,
    table: pyarrow.Table,
    progress: Optional[Callable[[int], Any]],
):
    """Write ``table`` for ``feature_view`` through ``SparkOfflineStore``.

    ``get_databricks_session`` is called first with the store config, then
    the arguments are handed to the parent implementation unchanged and its
    result is returned.
    """
    uc_store_config = config.offline_store
    assert isinstance(uc_store_config, DatabricksUCOfflineStoreConfig)
    get_databricks_session(uc_store_config)

    parent_kwargs = {
        "config": config,
        "feature_view": feature_view,
        "table": table,
        "progress": progress,
    }
    return SparkOfflineStore.offline_write_batch(**parent_kwargs)
@staticmethod
def ensure_monitoring_tables(config: RepoConfig) -> None:
    """Delegate monitoring-table setup to ``SparkOfflineStore``.

    ``get_databricks_session`` is called with the store config before the
    parent implementation runs; the parent's result is returned as-is.
    """
    uc_store_config = config.offline_store
    assert isinstance(uc_store_config, DatabricksUCOfflineStoreConfig)
    get_databricks_session(uc_store_config)

    outcome = SparkOfflineStore.ensure_monitoring_tables(config=config)
    return outcome
@staticmethod
def validate_data_source(
    config: RepoConfig,
    data_source: DataSource,
):
    """Validate ``data_source`` against ``config``.

    ``get_databricks_session`` is called with the store config first, then
    validation is delegated to ``data_source.validate``.
    """
    uc_store_config = config.offline_store
    assert isinstance(uc_store_config, DatabricksUCOfflineStoreConfig)
    get_databricks_session(uc_store_config)

    data_source.validate(config=config)
def _feast_to_spark_type_simple(field):
    """Convert a Feast :class:`Field` to a pyspark :class:`DataType`.

    This is a best-effort mapping keyed on the upper-cased string form of
    ``field.dtype``:

    * a missing ``dtype`` maps to ``StringType``;
    * any name containing ``LIST`` or ``ARRAY`` maps to
      ``ArrayType(StringType())`` (element types are not resolved);
    * names containing ``DOUBLE``, or both ``FLOAT`` and ``64``, map to
      ``DoubleType``;
    * remaining names are looked up in a fixed table, defaulting to
      ``StringType``.

    Args:
        field: Object exposing an optional ``dtype`` attribute.

    Returns:
        A pyspark ``DataType`` instance.
    """
    from pyspark.sql.types import (
        ArrayType,
        BinaryType,
        BooleanType,
        ByteType,
        DateType,
        DoubleType,
        FloatType,
        IntegerType,
        LongType,
        ShortType,
        StringType,
        TimestampType,
    )

    dtype = getattr(field, "dtype", None)
    if dtype is None:
        return StringType()  # fallback

    type_name = str(dtype).upper()

    # Containers must be recognised before any scalar substring test;
    # otherwise a name such as "ARRAY(FLOAT64)" is reported as a scalar
    # DoubleType because it contains both "FLOAT" and "64".
    if "LIST" in type_name or "ARRAY" in type_name:
        return ArrayType(StringType())

    if "DOUBLE" in type_name or ("FLOAT" in type_name and "64" in type_name):
        return DoubleType()

    type_map = {
        "BOOL": BooleanType(),
        "BOOLEAN": BooleanType(),
        "INT8": ByteType(),
        "BYTE": ByteType(),
        "INT16": ShortType(),
        "SHORT": ShortType(),
        "INT32": IntegerType(),
        "INT": IntegerType(),
        "INTEGER": IntegerType(),
        "INT64": LongType(),
        "LONG": LongType(),
        "BIGINT": LongType(),
        "FLOAT32": FloatType(),
        "FLOAT": FloatType(),
        "FLOAT64": DoubleType(),
        "DOUBLE": DoubleType(),
        "STRING": StringType(),
        "UTF8": StringType(),
        "BINARY": BinaryType(),
        "BYTES": BinaryType(),
        "TIMESTAMP": TimestampType(),
        "TIMESTAMP_TZ": TimestampType(),
        "UNIXTIMESTAMP": TimestampType(),
        "DATE": DateType(),
        "DATE32": DateType(),
    }
    return type_map.get(type_name, StringType())
def _build_full_table_name(
    catalog: Optional[str], schema: Optional[str], table: str
) -> str:
    """Build a backtick-quoted table reference: ``catalog.schema.table``.

    ``catalog`` and ``schema`` are omitted when empty or ``None``; ``table``
    is always included. Backticks inside a part are doubled so that the part
    stays a single quoted identifier.

    Args:
        catalog: Catalog name, or ``None``/empty to leave it out.
        schema: Schema name, or ``None``/empty to leave it out.
        table: Table name.

    Returns:
        The quoted parts joined with ``.``.
    """
    parts = [part for part in (catalog, schema) if part]
    parts.append(table)
    return ".".join("`" + part.replace("`", "``") + "`" for part in parts)
def _update_uc_feature_table_metadata(
    spark_session: SparkSession,
    fv: FeatureView,
    full_name: str,
    project: str,
) -> None:
    """Update comment, tags and owner on an existing UC feature table.

    Each statement is attempted independently; a failure is logged as a
    warning and the remaining statements still run, so this function never
    raises because of a rejected statement.

    Args:
        spark_session: Session used to run the SQL statements.
        fv: Feature view supplying ``description``, ``owner`` and tags.
        full_name: Already-quoted table reference, interpolated verbatim.
        project: Project name, passed on to ``_build_uc_tags``.
    """
    description = fv.description or ""
    owner = fv.owner
    tags = _build_uc_tags(fv, project)

    statements = [
        (
            "comment",
            f"COMMENT ON TABLE {full_name} IS '{_escape_sql_string(description)}'",
        )
    ]

    if tags:
        tag_pairs = ", ".join(
            f"'{_escape_sql_string(k)}' = '{_escape_sql_string(v)}'"
            for k, v in tags.items()
        )
        statements.append(
            ("tags", f"ALTER TABLE {full_name} SET TAGS ({tag_pairs})")
        )

    if owner:
        # The owner is a backtick-quoted identifier, not a string literal:
        # only embedded backticks need escaping (by doubling). String-literal
        # escaping would insert stray backslashes into the principal name.
        quoted_owner = "`" + owner.replace("`", "``") + "`"
        statements.append(
            ("owner", f"ALTER TABLE {full_name} SET OWNER TO {quoted_owner}")
        )

    for what, statement in statements:
        try:
            spark_session.sql(statement)
        except Exception as exc:
            # Best effort, but visible: the caller reports the table as
            # updated, so a rejected statement must not go unnoticed.
            logger.warning(
                "Could not update %s on UC table %s; continuing: %s",
                what,
                full_name,
                exc,
            )
def register_uc_feature_tables(
    config: DatabricksUCOfflineStoreConfig,
    feature_views: List[FeatureView],
    project: str,
) -> None:
    """Register or update FeatureViews as Unity Catalog feature tables.

    Returns without registering anything when:
    - ``uc_registration`` is not configured or ``enabled`` is ``False``.
    - ``databricks-feature-engineering`` is not installed.
    - The Databricks Spark session cannot be created.
    - The ``FeatureEngineeringClient`` cannot be constructed.

    Per-feature-view errors are logged and do not halt the apply.

    Args:
        config: Offline store config carrying ``uc_registration`` and the
            default catalog/schema.
        feature_views: Feature views to register or update.
        project: Project name forwarded to the per-view registration.
    """
    uc_config = config.uc_registration
    if uc_config is None or not uc_config.enabled:
        return

    try:
        from databricks.feature_engineering import FeatureEngineeringClient
    except ImportError:
        logger.info(
            "databricks-feature-engineering is not installed; "
            "skipping UC feature table registration. "
            "Install with: pip install databricks-feature-engineering"
        )
        return

    try:
        spark_session = get_databricks_session(config)
    except Exception as e:
        logger.warning(
            "Could not create Databricks Spark session for UC registration: %s", e
        )
        return

    # Constructing the client can fail too; registration is best-effort, so
    # treat that like a session failure instead of aborting the apply.
    try:
        fe_client = FeatureEngineeringClient()
    except Exception as e:
        logger.warning(
            "Could not create FeatureEngineeringClient for UC registration: %s", e
        )
        return

    # Settings on uc_registration take precedence over the store defaults.
    default_catalog = uc_config.catalog or config.default_catalog
    default_schema = uc_config.uc_schema or config.default_schema

    for fv in feature_views:
        try:
            _register_single_feature_view(
                spark_session,
                fe_client,
                fv,
                project,
                default_catalog,
                default_schema,
            )
        except Exception:
            logger.exception(
                "Failed to register UC feature table for FeatureView '%s'",
                fv.name,
            )
            click.echo(
                f" ✗ Failed to register UC feature table: {fv.name} "
                f"(check logs for details)"
            )
+ ) + return + + # Resolve UC path + default_catalog = uc_config.catalog or config.offline_store.default_catalog + default_schema = uc_config.uc_schema or config.offline_store.default_schema + catalog, schema, table = _resolve_uc_path(fv, default_catalog, default_schema) + if not catalog or not schema: + logger.warning( + f"Cannot materialize to UC for {fv.name}: missing catalog or schema." + ) + return + + full_name = _build_full_table_name(catalog, schema, table) + + # Get Spark session and construct Spark DataFrame if needed + try: + spark_session = get_databricks_session(config.offline_store) + except Exception as e: + logger.warning( + "Could not create Databricks Spark session for UC materialization: %s", e + ) + return + + import pyarrow as pa + + if isinstance(df, pa.Table): + # Convert pyarrow Table to pandas, then to Spark DataFrame + spark_df = spark_session.createDataFrame(df.to_pandas()) + else: + spark_df = df + + fe_client = FeatureEngineeringClient() + + # If the feature table does not exist in UC, register/create it first + if not _table_exists(spark_session, full_name): + primary_keys = _get_primary_keys(fv) + try: + _create_uc_feature_table( + fe_client, + spark_session, + fv, + full_name, + primary_keys, + project, + ) + logger.info(f"Created UC feature table: {full_name}") + except Exception as e: + logger.exception( + f"Failed to create UC feature table {full_name} during materialization" + ) + raise e + + # Write/merge into the UC table + try: + fe_client.write_table( + name=full_name, + df=spark_df, + mode="merge", + ) + logger.info(f"Successfully materialized features to UC table: {full_name}") + except Exception as e: + logger.exception(f"Failed to write/merge to UC feature table {full_name}") + raise e diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 971c0325f4b..847f933d4f6 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -96,6 +96,7 @@ "redshift": 
"feast.infra.offline_stores.redshift.RedshiftOfflineStore", "snowflake.offline": "feast.infra.offline_stores.snowflake.SnowflakeOfflineStore", "spark": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore", + "databricks_uc": "feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc.DatabricksUCOfflineStore", "trino": "feast.infra.offline_stores.contrib.trino_offline_store.trino.TrinoOfflineStore", "postgres": "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStore", "athena": "feast.infra.offline_stores.contrib.athena_offline_store.athena.AthenaOfflineStore", diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py new file mode 100644 index 00000000000..07bfd1d5416 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_databricks_uc.py @@ -0,0 +1,193 @@ +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc import ( + DatabricksUCOfflineStore, + DatabricksUCOfflineStoreConfig, + get_databricks_session, +) +from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( + SparkSource, +) +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig +from feast.repo_config import RepoConfig + + +def test_config_parsing(): + config_dict = { + "type": "databricks_uc", + "workspace_host": "adb-12345.azuredatabricks.net", + "token": "dapi123456", + "cluster_id": "0123-4567-abcde", + "default_catalog": "main", + "default_schema": "default", + "spark_conf": {"spark.sql.shuffle.partitions": "10"}, + } + config = DatabricksUCOfflineStoreConfig(**config_dict) + assert config.type == "databricks_uc" + assert config.workspace_host == 
# Force the "databricks-connect not installed" path: a None entry in
# sys.modules makes ``from databricks.connect import ...`` raise ImportError,
# so the patched SparkSession.builder is used even where the package exists.
@patch.dict("sys.modules", {"databricks.connect": None})
@patch("pyspark.sql.SparkSession.getActiveSession")
@patch("pyspark.sql.SparkSession.builder")
def test_get_databricks_session_new_remote(mock_builder, mock_get_active):
    """A new remote session is built from host, token and cluster id."""
    mock_get_active.return_value = None
    mock_session = MagicMock()
    mock_builder.remote.return_value.config.return_value.getOrCreate.return_value = (
        mock_session
    )

    config = DatabricksUCOfflineStoreConfig(
        type="databricks_uc",
        workspace_host="https://adb-12345.azuredatabricks.net",
        token="dapi123",
        cluster_id="0123-4567-abcde",
        spark_conf={"spark.some.option": "value"},
    )

    session = get_databricks_session(config)

    assert session == mock_session
    # The scheme is stripped and token/cluster id become URI parameters.
    mock_builder.remote.assert_called_once_with(
        "sc://adb-12345.azuredatabricks.net:443/;token=dapi123;x-databricks-cluster-id=0123-4567-abcde"
    )
"feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore.get_historical_features" +) +def test_get_historical_features_delegation(mock_parent_features, mock_get_session): + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + repo_config = RepoConfig( + registry="file:///tmp/registry.db", + project="test", + provider="local", + online_store=SqliteOnlineStoreConfig(type="sqlite"), + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + workspace_host="adb-123.databricks.com", + cluster_id="123", + ), + ) + + feature_views = [] + feature_refs = ["fv:f1"] + entity_df = MagicMock() + registry = MagicMock() + + DatabricksUCOfflineStore.get_historical_features( + config=repo_config, + feature_views=feature_views, + feature_refs=feature_refs, + entity_df=entity_df, + registry=registry, + project="test", + ) + + mock_get_session.assert_called_once_with(repo_config.offline_store) + mock_parent_features.assert_called_once_with( + config=repo_config, + feature_views=feature_views, + feature_refs=feature_refs, + entity_df=entity_df, + registry=registry, + project="test", + full_feature_names=False, + ) + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc.get_databricks_session" +) +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore.pull_latest_from_table_or_query" +) +def test_pull_latest_from_table_or_query_delegation( + mock_parent_pull_latest, mock_get_session +): + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + repo_config = RepoConfig( + registry="file:///tmp/registry.db", + project="test", + provider="local", + online_store=SqliteOnlineStoreConfig(type="sqlite"), + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + workspace_host="adb-123.databricks.com", + cluster_id="123", + ), + ) + + data_source = SparkSource( + name="test_source", + path="catalog.schema.table", + 
file_format="parquet", + timestamp_field="ts", + ) + + start_date = datetime(2023, 1, 1, tzinfo=timezone.utc) + end_date = datetime(2023, 1, 2, tzinfo=timezone.utc) + + DatabricksUCOfflineStore.pull_latest_from_table_or_query( + config=repo_config, + data_source=data_source, + join_key_columns=["id"], + feature_name_columns=["val"], + timestamp_field="ts", + created_timestamp_column=None, + start_date=start_date, + end_date=end_date, + ) + + mock_get_session.assert_called_once_with(repo_config.offline_store) + mock_parent_pull_latest.assert_called_once_with( + config=repo_config, + data_source=data_source, + join_key_columns=["id"], + feature_name_columns=["val"], + timestamp_field="ts", + created_timestamp_column=None, + start_date=start_date, + end_date=end_date, + ) diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_uc_registration.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_uc_registration.py new file mode 100644 index 00000000000..e49868f8d23 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_uc_registration.py @@ -0,0 +1,639 @@ +from unittest.mock import MagicMock, patch + +from pyspark.sql.types import ( + ArrayType, + BinaryType, + BooleanType, + DoubleType, + FloatType, + IntegerType, + LongType, + StringType, + StructType, + TimestampType, +) + +from feast import FeatureView +from feast.infra.offline_stores.contrib.spark_offline_store.databricks_uc import ( + DatabricksUCOfflineStoreConfig, + UCRegistrationConfig, +) +from feast.infra.offline_stores.contrib.spark_offline_store.uc_registration import ( + _build_full_table_name, + _build_spark_schema, + _build_uc_tags, + _feast_to_spark_type_simple, + _get_primary_keys, + _resolve_uc_path, + _should_register, + register_uc_feature_tables, + write_uc_materialized_data, +) +from feast.repo_config import RepoConfig + + +def make_mock_field(name: str, dtype_str: str = "", nullable: bool = 
True): + field = MagicMock() + field.name = name + # _feast_to_spark_type_simple does str(dtype).upper() + # When dtype_str is empty, dtype will be None -> StringType fallback + if dtype_str: + + class FakeDtype: + def __str__(self): + return dtype_str + + field.dtype = FakeDtype() + else: + field.dtype = None + return field + + +def test_feast_to_spark_type_simple_none(): + field = make_mock_field("col", "") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, StringType) + + +def test_feast_to_spark_type_simple_int32(): + field = make_mock_field("col", "Int32") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, IntegerType) + + +def test_feast_to_spark_type_simple_int64(): + field = make_mock_field("col", "Int64") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, LongType) + + +def test_feast_to_spark_type_simple_float32(): + field = make_mock_field("col", "Float32") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, FloatType) + + +def test_feast_to_spark_type_simple_float64(): + field = make_mock_field("col", "Float64") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, DoubleType) + + +def test_feast_to_spark_type_simple_string(): + field = make_mock_field("col", "String") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, StringType) + + +def test_feast_to_spark_type_simple_bool(): + field = make_mock_field("col", "Bool") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, BooleanType) + + +def test_feast_to_spark_type_simple_bytes(): + field = make_mock_field("col", "Bytes") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, BinaryType) + + +def test_feast_to_spark_type_simple_unix_timestamp(): + field = make_mock_field("col", "UnixTimestamp") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, TimestampType) + + +def test_feast_to_spark_type_simple_list(): 
+ field = make_mock_field("col", "List") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, ArrayType) + + +def test_feast_to_spark_type_simple_array(): + field = make_mock_field("col", "Array") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, ArrayType) + + +def test_feast_to_spark_type_simple_unknown(): + field = make_mock_field("col", "UnknownType") + result = _feast_to_spark_type_simple(field) + assert isinstance(result, StringType) + + +def test_should_register_default(): + fv = MagicMock(spec=FeatureView) + fv.tags = {} + assert _should_register(fv) is True + + +def test_should_register_true(): + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.register_as_feature_table": "true"} + assert _should_register(fv) is True + + +def test_should_register_false(): + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.register_as_feature_table": "false"} + assert _should_register(fv) is False + + +def test_should_register_case_insensitive(): + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.register_as_feature_table": "FALSE"} + assert _should_register(fv) is False + + +def test_resolve_uc_path_uses_defaults(): + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "my_feature_view" + catalog, schema, table = _resolve_uc_path(fv, "default_cat", "default_sch") + assert catalog == "default_cat" + assert schema == "default_sch" + assert table == "my_feature_view" + + +def test_resolve_uc_path_uses_tags(): + fv = MagicMock(spec=FeatureView) + fv.tags = { + "uc.catalog": "tag_cat", + "uc.schema": "tag_sch", + "uc.table": "tag_table", + } + fv.name = "ignored" + catalog, schema, table = _resolve_uc_path(fv, "default_cat", "default_sch") + assert catalog == "tag_cat" + assert schema == "tag_sch" + assert table == "tag_table" + + +def test_resolve_uc_path_sanitizes_table_name(): + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "my-feature.view" + catalog, schema, table = _resolve_uc_path(fv, "cat", "sch") + 
assert table == "my_feature_view" + + +def test_resolve_uc_path_tag_overrides_none_default(): + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.catalog": "tag_cat"} + fv.name = "fv_name" + catalog, schema, table = _resolve_uc_path(fv, None, None) + assert catalog == "tag_cat" + assert schema is None + assert table == "fv_name" + + +def test_build_uc_tags(): + fv = MagicMock(spec=FeatureView) + fv.tags = { + "env": "prod", + "uc.register_as_feature_table": "false", + "owner": "team_a", + } + tags = _build_uc_tags(fv, "my_project") + assert tags["env"] == "prod" + assert tags["owner"] == "team_a" + assert "uc.register_as_feature_table" not in tags + assert tags["feast_managed"] == "feast" + assert tags["feast_project"] == "my_project" + + +def test_build_uc_tags_empty_tags(): + fv = MagicMock(spec=FeatureView) + fv.tags = {} + tags = _build_uc_tags(fv, "proj") + assert tags["feast_managed"] == "feast" + assert tags["feast_project"] == "proj" + + +def test_get_primary_keys(): + fv = MagicMock(spec=FeatureView) + col1 = make_mock_field("id", "Int32") + col2 = make_mock_field("date", "String") + fv.entity_columns = [col1, col2] + assert _get_primary_keys(fv) == ["id", "date"] + + +def test_get_primary_keys_empty(): + fv = MagicMock(spec=FeatureView) + fv.entity_columns = [] + assert _get_primary_keys(fv) == [] + + +def test_build_full_table_name_all_parts(): + assert _build_full_table_name("cat", "sch", "table") == "`cat`.`sch`.`table`" + + +def test_build_full_table_name_no_catalog(): + assert _build_full_table_name(None, "sch", "table") == "`sch`.`table`" + + +def test_build_full_table_name_no_schema(): + assert _build_full_table_name("cat", None, "table") == "`cat`.`table`" + + +def test_build_full_table_name_only_table(): + assert _build_full_table_name(None, None, "table") == "`table`" + + +def test_build_spark_schema(): + fv = MagicMock(spec=FeatureView) + id_field = make_mock_field("entity_id", "Int32") + feat_field = make_mock_field("feature_val", "Float64") + 
fv.entity_columns = [id_field] + fv.features = [feat_field] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = "event_ts" + + schema = _build_spark_schema(fv) + + assert isinstance(schema, StructType) + fields = schema.fields + assert fields[0].name == "entity_id" + assert isinstance(fields[0].dataType, IntegerType) + assert fields[0].nullable is False + + assert fields[1].name == "feature_val" + assert isinstance(fields[1].dataType, DoubleType) + assert fields[1].nullable is True + + assert fields[2].name == "event_ts" + assert isinstance(fields[2].dataType, TimestampType) + assert fields[2].nullable is True + + +def test_build_spark_schema_no_timestamp(): + fv = MagicMock(spec=FeatureView) + id_field = make_mock_field("id", "Int64") + fv.entity_columns = [id_field] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + + schema = _build_spark_schema(fv) + assert len(schema.fields) == 1 + assert schema.fields[0].name == "id" + + +def test_register_uc_feature_tables_skips_when_disabled(): + config = DatabricksUCOfflineStoreConfig(type="databricks_uc") + register_uc_feature_tables(config, [], "test_project") + # No error means success (skip silently) + + +def test_register_uc_feature_tables_skips_when_no_uc_config(): + config = DatabricksUCOfflineStoreConfig( + type="databricks_uc", + uc_registration=UCRegistrationConfig(enabled=False), + ) + register_uc_feature_tables(config, [], "test_project") + # No error means success + + +def _run_with_mock_fe(config, fvs, project, table_exists: bool = False): + """Helper to run register_uc_feature_tables with mocked external deps.""" + fe_module = MagicMock() + fe_module.FeatureEngineeringClient = MagicMock() + with patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + with patch( + "feast.infra.offline_stores.contrib.spark_offline_store.uc_registration.get_databricks_session" + ) as mock_get_session: + fe_client = MagicMock() + 
fe_module.FeatureEngineeringClient.return_value = fe_client + mock_session = MagicMock() + mock_session.catalog.tableExists.return_value = table_exists + mock_get_session.return_value = mock_session + register_uc_feature_tables(config, fvs, project) + return fe_client, mock_session + + +def test_register_uc_feature_tables_creates_new_table(): + config = DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ) + + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "test_fv" + fv.entity_columns = [] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + fv.description = "test description" + fv.owner = "user1" + + fe_client, mock_session = _run_with_mock_fe( + config, [fv], "proj", table_exists=False + ) + + fe_client.create_table.assert_called_once() + call_kwargs = fe_client.create_table.call_args[1] + assert call_kwargs["name"] == "`cat`.`sch`.`test_fv`" + assert call_kwargs["primary_keys"] == [] + assert call_kwargs["description"] == "test description" + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.uc_registration.get_databricks_session" +) +def test_register_uc_feature_tables_updates_existing( + mock_get_session, +): + mock_session = MagicMock() + mock_get_session.return_value = mock_session + mock_session.catalog.tableExists.return_value = True + + fe_module = MagicMock() + fe_client = MagicMock() + fe_module.FeatureEngineeringClient.return_value = fe_client + + config = DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ) + + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "test_fv" + fv.entity_columns = [] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + fv.description = "test description" + fv.owner = "user1" + + with 
patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + register_uc_feature_tables(config, [fv], "proj") + + fe_client.create_table.assert_not_called() + mock_session.sql.assert_any_call( + "COMMENT ON TABLE `cat`.`sch`.`test_fv` IS 'test description'" + ) + + +@patch( + "feast.infra.offline_stores.contrib.spark_offline_store.uc_registration.get_databricks_session" +) +def test_register_uc_feature_tables_skips_opt_out( + mock_get_session, +): + mock_session = MagicMock() + mock_get_session.return_value = mock_session + mock_session.catalog.tableExists.return_value = False + + config = DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ) + + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.register_as_feature_table": "false"} + fv.name = "opt_out_fv" + fv.entity_columns = [] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + fv.description = "" + fv.owner = None + + fe_module = MagicMock() + fe_client = MagicMock() + fe_module.FeatureEngineeringClient.return_value = fe_client + with patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + register_uc_feature_tables(config, [fv], "proj") + + # Per-view opt-out means create_table should NOT be called + fe_client.create_table.assert_not_called() + + +def test_register_uc_feature_tables_skips_when_missing_catalog(): + fe_module = MagicMock() + fe_client = MagicMock() + fe_module.FeatureEngineeringClient.return_value = fe_client + with patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + with patch( + "feast.infra.offline_stores.contrib.spark_offline_store.uc_registration.get_databricks_session" + ) as mock_get_session: + mock_session = MagicMock() + mock_get_session.return_value = mock_session + + config = DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog=None, + default_schema=None, 
+ uc_registration=UCRegistrationConfig(enabled=True), + ) + + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "no_cat_fv" + fv.entity_columns = [] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + fv.description = "" + fv.owner = None + + register_uc_feature_tables(config, [fv], "proj") + + fe_client.create_table.assert_not_called() + + +# ────────────────────────────────────────────── +# Tests for write_uc_materialized_data (L3) +# ────────────────────────────────────────────── + + +def _run_write_with_mock( + config, + fv, + df, + project, + table_exists: bool = False, + mock_spark_session=None, +): + """Helper to run write_uc_materialized_data with mocked external deps.""" + fe_module = MagicMock() + fe_module.FeatureEngineeringClient = MagicMock() + with patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + with patch( + "feast.infra.offline_stores.contrib.spark_offline_store.uc_registration.get_databricks_session" + ) as mock_get_session: + fe_client = MagicMock() + fe_module.FeatureEngineeringClient.return_value = fe_client + if mock_spark_session is None: + mock_session = MagicMock() + else: + mock_session = mock_spark_session + mock_session.catalog.tableExists.return_value = table_exists + mock_get_session.return_value = mock_session + write_uc_materialized_data(config, fv, df, project) + return fe_client, mock_session + + +def test_write_uc_materialized_data_skips_wrong_offline_store(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=MagicMock(), + ) + fv = MagicMock(spec=FeatureView) + write_uc_materialized_data(config, fv, MagicMock(), "test") + + +def test_write_uc_materialized_data_skips_when_disabled(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + 
uc_registration=UCRegistrationConfig(enabled=False), + ), + ) + fv = MagicMock(spec=FeatureView) + write_uc_materialized_data(config, fv, MagicMock(), "test") + + +def test_write_uc_materialized_data_skips_opt_out(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ), + ) + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.register_as_feature_table": "false"} + fv.name = "opt_out_fv" + + fe_module = MagicMock() + with patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + write_uc_materialized_data(config, fv, MagicMock(), "test") + + fe_module.FeatureEngineeringClient.assert_not_called() + + +def test_write_uc_materialized_data_skips_materialize_offline_false(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ), + ) + fv = MagicMock(spec=FeatureView) + fv.tags = {"uc.materialize_offline": "false"} + fv.name = "skip_mat_fv" + + fe_module = MagicMock() + with patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + write_uc_materialized_data(config, fv, MagicMock(), "test") + + fe_module.FeatureEngineeringClient.assert_not_called() + + +def test_write_uc_materialized_data_skips_missing_catalog(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog=None, + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ), + ) + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "test_fv" + + fe_module = MagicMock() + with 
patch.dict("sys.modules", {"databricks.feature_engineering": fe_module}): + write_uc_materialized_data(config, fv, MagicMock(), "test") + + fe_module.FeatureEngineeringClient.assert_not_called() + + +def test_write_uc_materialized_data_writes_merge(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ), + ) + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "test_fv" + fv.entity_columns = [] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + fv.description = "" + fv.owner = None + + fe_client, _ = _run_write_with_mock( + config, fv, MagicMock(), "test", table_exists=True + ) + + fe_client.write_table.assert_called_once() + call_args = fe_client.write_table.call_args[1] + assert call_args["name"] == "`cat`.`sch`.`test_fv`" + assert call_args["mode"] == "merge" + + +def test_write_uc_materialized_data_creates_then_writes(): + config = RepoConfig( + registry="file:///tmp/reg.db", + project="test", + provider="local", + offline_store=DatabricksUCOfflineStoreConfig( + type="databricks_uc", + default_catalog="cat", + default_schema="sch", + uc_registration=UCRegistrationConfig(enabled=True), + ), + ) + fv = MagicMock(spec=FeatureView) + fv.tags = {} + fv.name = "new_fv" + fv.entity_columns = [] + fv.features = [] + fv.batch_source = MagicMock() + fv.batch_source.timestamp_field = None + fv.description = "" + fv.owner = None + + fe_client, _ = _run_write_with_mock( + config, fv, MagicMock(), "test", table_exists=False + ) + + fe_client.create_table.assert_called_once() + fe_client.write_table.assert_called_once() + assert fe_client.write_table.call_args[1]["mode"] == "merge"