
Snorkeling with Snowflake

How to use Ray to execute massively parallel compute on a Snowflake dataset

Ray and its managed offering, Anyscale, have been making waves lately, primarily because of how straightforward they make it for developers to leverage massively parallel compute with super simple idioms.

I have been using Snowflake at Toplyne for a long time and have been looking for ways to combine it with Ray to unlock that next level of compute power.

The Snowflake Python connector offers a basic API that lets us pull data from Snowflake in batches. To date, I have been using Python multithreading to patch together basic Snowflake workflows: boot up multiple threads and map individual batches of Snowflake data onto them. This gives us an apparent increase in throughput.

  from concurrent.futures import ThreadPoolExecutor

  from snowflake.connector import connect
  from snowflake.connector.result_batch import ResultBatch

  connect_args = {...}  # Snowflake account, credentials, etc.
  query = "select * from SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.LINEITEM"


  def _compute(_sf_batch: ResultBatch):
      arrow_table = _sf_batch.to_arrow()
      # do more compute on this arrow table


  # Pull the lightweight batch descriptors on the main thread.
  with connect(**connect_args) as conn:
      with conn.cursor() as cur:
          cur.execute(query)
          batches = cur.get_result_batches()

  # Map each batch to a thread; the executor waits for all batches and
  # shuts down when the `with` block exits.
  with ThreadPoolExecutor() as _texec:
      for batch in batches:
          _texec.submit(_compute, batch)

Multithreading is built into Python and works well for a lot of general-purpose tooling. We can marry Snowflake’s APIs with multithreading to patch together a workflow, but this approach is confined to a single machine and therefore has limited horizontal scalability.

Let's see what this workflow will look like in Ray:

  import ray
  import pyarrow as pa

  connect_args = {...}  # Snowflake account, credentials, etc.
  query = "select * from SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.LINEITEM"


  def _compute(arrow_table: pa.Table) -> pa.Table:
      # do more compute on this arrow table
      return arrow_table


  snowflake_datasource = SnowflakeDatasource(connect_args, query)
  rds = ray.data.read_datasource(snowflake_datasource)
  transformed = rds.map_batches(_compute, batch_format="pyarrow")

This is what is so exciting about Ray. Simple idioms and maximum compute.
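
One note: depending on your Ray version, Dataset transforms like map_batches are lazy, so nothing actually runs until you consume or materialize the result. A minimal sketch of forcing execution, assuming Ray 2.x APIs such as materialize and take:

  # Sketch: force the lazy pipeline to run and peek at a few rows.
  # materialize()/take() are Ray 2.x Dataset APIs; adjust for your version.
  transformed = transformed.materialize()  # executes the reads + _compute across the cluster
  print(transformed.take(5))               # pull a handful of transformed rows back to the driver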

But wait, what is SnowflakeDatasource?

Ray’s ray-data library provides APIs to load data from different sources. It ships with a bunch of general-purpose readers for well-defined data sources; however, there is currently no reader specifically for Snowflake.
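
For example, the stock readers cover the usual file formats (the paths below are just placeholders):

  import ray

  # Built-in Ray Data readers for common formats -- nothing Snowflake-specific here.
  parquet_ds = ray.data.read_parquet("s3://my-bucket/some/table/")  # placeholder path
  csv_ds = ray.data.read_csv("/tmp/example.csv")                    # placeholder path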

Fortunately, implementing a data source for Snowflake is pretty straightforward.

How do we go about it though? We should collect some data points first:

  1. Ray has a guide that walks through the implementation of a Mongo Datasource.

  2. Anyscale’s GitHub repo has a fork of Ray Data that includes an implementation of a SnowflakeDatasource as well.

  3. Additionally, Ray has documented the Block API, which is fundamental to Ray’s internal data representation.

The key insight here is that individual Snowflake result batches map neatly onto Ray’s data blocks. With these data points in hand, we can start on our implementation.
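
To make that mapping concrete, here is roughly what a ResultBatch exposes and how it lines up with a Ray block. Attribute names are from snowflake-connector-python; double-check them against your installed version. connect_args and query are the same placeholders as in the earlier snippets.

  from snowflake.connector import connect

  # Sketch: a ResultBatch is a lightweight descriptor until you call to_arrow().
  with connect(**connect_args) as conn:
      with conn.cursor() as cur:
          cur.execute(query)
          batches = cur.get_result_batches()

  first = batches[0]
  print(first.rowcount)           # rows in this batch      -> BlockMetadata.num_rows
  print(first.uncompressed_size)  # bytes once decompressed -> BlockMetadata.size_bytes
  print(first.schema)             # column metadata         -> BlockMetadata.schema
  table = first.to_arrow()        # the actual data pull (arrow result batches only)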

We need to implement two Ray classes and three methods to get the entire thing going (a bare skeleton follows the list):

  1. ray.data.datasource.datasource.Datasource
     a) create_reader

  2. ray.data.datasource.datasource.Reader
     a) estimate_inmemory_data_size
     b) get_read_tasks
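
Before filling in the details, here is the bare shape of what we will implement (method names per the Ray Data custom-datasource API this post targets; bodies elided):

  from typing import Optional

  from ray.data import ReadTask
  from ray.data.datasource import Datasource, Reader


  class _SnowflakeDatasourceReader(Reader):
      def estimate_inmemory_data_size(self) -> Optional[int]:
          ...  # rough byte estimate of the full result set

      def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
          ...  # one lazy ReadTask per Snowflake ResultBatch


  class SnowflakeDatasource(Datasource):
      def create_reader(self, **read_args) -> Reader:
          ...  # hand back a configured _SnowflakeDatasourceReader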

Now that we know which parts of the Ray Data API we need, we can dive deeper into the what and why of each:

1. Reader.get_read_tasks:
a) Create a Snowflake connection.
b) Execute the query.
c) Fetch the snowflake ResultBatches.
d) Generate read tasks.

Each read task fetches its data batch from Snowflake on a worker. Since Ray’s APIs are lazy, the memory footprint of this step on the driver is minimal.
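
Under the hood a ReadTask is just a callable that Ray ships to a worker; no bytes move from Snowflake until a worker actually invokes it. As an aside, instead of subclassing (as I do below) you could also build one directly from a closure. A sketch, assuming the ReadTask(read_fn, metadata) constructor:

  import pyarrow as pa
  from ray.data import ReadTask
  from ray.data.block import BlockMetadata
  from snowflake.connector.result_batch import ResultBatch


  # Sketch: wrap a ResultBatch in a ReadTask via a zero-argument read function.
  # The function only runs on the worker that executes the task.
  def make_read_task(batch: ResultBatch, metadata: BlockMetadata) -> ReadTask:
      def _read_fn() -> list[pa.Table]:
          # First (and only) point where data is pulled from Snowflake.
          return [batch.to_arrow()]

      return ReadTask(_read_fn, metadata)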

2. Reader.estimate_inmemory_data_size:
Estimate the total size of the data as it will be once loaded into memory.

For our use case, I’ll approximate it from the uncompressed batch sizes that Snowflake reports, which is a reasonable proxy for the size of the resulting PyArrow tables.

3. Datasource.create_reader:
Create an instance of the reader that implements the two methods above.

Now that we know the what and why, let's get into the how. I’ve added more detail in the comments of the code below.

  import logging
  from functools import cached_property
  from typing import Iterable, Optional

  import pyarrow as pa

  from ray.data import ReadTask
  from ray.data.block import BlockMetadata
  from ray.data.datasource import Reader
  from snowflake.connector import connect
  from snowflake.connector.result_batch import ResultBatch

  ray_data_logger = logging.getLogger(__name__)

  # FIELD_TYPE_TO_PA_TYPE is a mapping from Snowflake type codes to the
  # corresponding pyarrow types; it is defined elsewhere in the module.


  # The reader performs the heavy lifting.
  class _SnowflakeDatasourceReader(Reader):
      def __init__(self, connection_args: dict, query: str):
          # connection info like snowflake account name & credentials.
          self._connection_args = connection_args
          # the query to execute.
          self._query = query

      # this property is reused both when creating the read tasks
      # and when estimating the in-memory data size.
      @cached_property
      def _result_batches(self):
          # connect with snowflake
          with connect(**self._connection_args) as conn:
              # get the cursor
              with conn.cursor() as cur:
                  cur.execute(self._query)
                  # Get the result as batches.
                  # This API has a minimal memory footprint because
                  # a ResultBatch doesn't hold any data. It only
                  # tells us how to pull the data and what size/schema
                  # to expect once it lands.
                  # The driver therefore holds almost nothing in memory
                  # and can safely do the work of creating the relevant
                  # Block(s) for Ray.
                  batches = cur.get_result_batches()
          return batches

      def estimate_inmemory_data_size(self) -> Optional[int]:
          sz = None
          for batch in self._result_batches:
              sz = (sz or 0) + (batch.uncompressed_size or 0)
          ray_data_logger.info("Estimating in-memory data size %s", sz)
          return sz

      def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
          read_tasks = []
          for batch in self._result_batches:
              # Map the batch metadata to the ray block metadata.
              metadata = BlockMetadata(
                  num_rows=batch.rowcount,
                  size_bytes=batch.uncompressed_size,
                  schema=pa.schema(
                      [
                          pa.field(s.name, FIELD_TYPE_TO_PA_TYPE[s.type_code])
                          for s in batch.schema
                      ]
                  ),
                  input_files=None,
                  exec_stats=None,
              )
              # Create a lazy handler that will load up the
              # ResultBatch on the worker and do the actual
              # pull from snowflake.
              _r_task = LazyReadTask(
                  arrow_batch=batch,
                  metadata=metadata,
              )
              read_tasks.append(_r_task)
          return read_tasks


  # This read task is what executes on the worker(s), pulls the data
  # from snowflake, and returns a PyArrow table.
  class LazyReadTask(ReadTask):
      def __init__(self, arrow_batch: ResultBatch, metadata: BlockMetadata):
          self._arrow_batch = arrow_batch
          self._metadata = metadata

      def _read_fn(self) -> Iterable[pa.Table]:
          ray_data_logger.debug(
              "Reading %s rows from Snowflake", self._metadata.num_rows
          )
          return [self._arrow_batch.to_arrow()]

Woah 😅, that is smooth.

Now let’s quickly put together the Datasource itself, which is the piece we hand to Ray.

  from ray.data.datasource import Datasource, Reader


  class SnowflakeDatasource(Datasource):
      def __init__(self, connection_args: dict, query: str):
          self._connection_args = connection_args
          self._query = query

      def create_reader(self, **read_args) -> Reader:
          # Yesss! This is the Reader you just implemented.
          return _SnowflakeDatasourceReader(
              connection_args=self._connection_args,
              query=self._query,
          )

  # This is it. You are not missing anything.
  # To reaffirm: this is it. You are not missing anything.

That is all!

You have a Snowflake data source. The next time you want to use some Ray goodness on Snowflake, you won’t be left wanting for a fast-reading data source.

You already got it here.

Now please do the cool stuff and show it to me.

Appendix:

  1. The GitHub repo with my implementation.

  2. The Anyscale blog that motivated me: https://www.anyscale.com/blog/introducing-the-anyscale-snowflake-connector

  3. The corresponding Anyscale fork: https://github.com/anyscale/datasets-database/blob/master/python/ray/data/datasource/snowflake_datasource.py
