Skip to content

extractEmbedding

Short summary of the script: extracts latent embeddings from images using a pre-trained autoencoder model.

This function handles both single-channel and multi-channel datasets, loading the specified
model and processing the dataset directory's images to generate embeddings. The resulting
embeddings are saved to the designated output path.

The `extract_embedding` function extracts and saves the embeddings.
The `extract3D_embedding` function extracts and saves the embeddings from 3D models.

extract3DEmbedding(ModelPath, dataset_dir, embeddingSavePath, embeddingSaveName='latent_embeddings.csv', embedding_size=256, batch_size=8, multi=False, channels=None)

Extract embedding from the specified images and trained 3D model.

Parameters:

ModelPath (str):
    Path to the saved trained 3D model.

dataset_dir (str):
    The file path leading to the directory that holds the images data.

embeddingSavePath (str):
    Path to output directory for saving the extracted embedding file.

embeddingSaveName (str):
    FileName for saving the extracted embedding file.

embedding_size (int, optional):
    Encoding dimension of the embedding to extract.

batch_size (int, optional):
    batch size for dataloader.

multi (bool, optional):
    Multichannels or not. Default False.

channels (list, optional):
    Specified list of markers to identify the Path of each marker's folder.
    Default None. If `multi` is True, `channels` must not be None.

Example:

```python
multi = False
ModelPath='/n/scratch/users/r/roh6824/Results/LSP13626_DNA_padding/SpatialAE/ln_3Dautoencoder_DNA_validate_2_model_test_update.pth'
dataset_dir="/n/scratch/users/r/roh6824/Results/LSP13626_DNA_padding/SpatialAE/Single3DPatch/DNA1/"
embeddingSavePath = "/home/roh6824/ResearchProject/SpatialMolecular/datasets/LSP13626/analysis/"
embeddingSaveName = "LSP13626_DNA_dim256_selected_latent_embeddings.csv"

extract3DEmbedding(ModelPath,
                   dataset_dir,
                   embeddingSavePath,
                   embeddingSaveName,
                   embedding_size = 256,
                   batch_size = 8)
```
Source code in spatialae/datasets/extract.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def extract3DEmbedding(ModelPath,
                       dataset_dir,
                       embeddingSavePath,
                       embeddingSaveName = "latent_embeddings.csv",
                       embedding_size = 256,
                       batch_size = 8,
                       multi = False,
                       channels = None
                       ):
    """
    Extract embeddings from the specified images with a trained 3D model.

    Builds a ``LitAutoEncoder3D_Complex``, loads the weights stored at
    ``ModelPath`` into it, wraps the images found under ``dataset_dir`` in a
    dataset and a ``DataLoader``, and passes both to ``extract3D_embedding``,
    which writes the embeddings to ``embeddingSavePath/embeddingSaveName``.

    Parameters:

        ModelPath (str):
            Path to the saved trained 3D model.

        dataset_dir (str):
            The file path leading to the directory that holds the image data.

        embeddingSavePath (str):
            Path to the output directory for saving the extracted embedding file.

        embeddingSaveName (str):
            File name for saving the extracted embedding file.

        embedding_size (int, optional):
            Encoding dimension passed to the model constructor. Default 256.

        batch_size (int, optional):
            Batch size for the dataloader. Default 8.

        multi (bool, optional):
            Multi-channel data or not. Default False.

        channels (list, optional):
            List of markers identifying each marker's folder.
            Default None. If `multi` is True, `channels` must be a non-empty list.

    Returns:
        The value returned by ``extract3D_embedding``.

    Raises:
        ValueError: If `multi` is True and `channels` is None or empty.

    Example:

        ```python
        multi = False
        ModelPath='/n/scratch/users/r/roh6824/Results/LSP13626_DNA_padding/SpatialAE/ln_3Dautoencoder_DNA_validate_2_model_test_update.pth'
        dataset_dir="/n/scratch/users/r/roh6824/Results/LSP13626_DNA_padding/SpatialAE/Single3DPatch/DNA1/"
        embeddingSavePath = "/home/roh6824/ResearchProject/SpatialMolecular/datasets/LSP13626/analysis/"
        embeddingSaveName = "LSP13626_DNA_dim256_selected_latent_embeddings.csv"

        extract3DEmbedding(ModelPath,
                           dataset_dir,
                           embeddingSavePath,
                           embeddingSaveName,
                           embedding_size = 256,
                           batch_size = 8)
        ```

    """
    # Validate before doing any expensive work (model construction, disk reads).
    if multi:
        if channels is None or len(channels) == 0:
            raise ValueError("Should provide marker channels list.")
        n_channels = len(channels)
    else:
        n_channels = 1

    # Reload the model. The same channel count is used for input and output.
    model = LitAutoEncoder3D_Complex(n_channels, n_channels, embedding_size)
    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # CPU-only machine; extract3D_embedding moves the model to the device it uses.
    model.load_state_dict(torch.load(ModelPath, map_location="cpu"))

    transform = ToTensor3D()
    if multi:
        used_dataset = MultiChannelSpatial3DImageDataset(dataset_dir, channels, transform=transform)
    else:
        used_dataset = Spatial3DImageDataset(dataset_dir, transform=transform)

    # Never shuffle for extraction: rows are written in loader order, and a
    # stable order keeps repeated runs producing the same file.
    data_loader = DataLoader(used_dataset, batch_size=batch_size, shuffle=False)

    return extract3D_embedding(data_loader,
                               model=model,
                               outdir=embeddingSavePath,
                               filename=embeddingSaveName)

extract3D_embedding(data_loader, model, outdir, filename='latent_embeddings.csv', device=None)

Example usage
extract3D_embedding(data_loader=my_data_loader, model=my_model, outdir='path/to/output/', device='cuda')
Source code in spatialae/datasets/extract.py
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
def extract3D_embedding(data_loader, model, outdir, filename="latent_embeddings.csv", device=None):
    """
    Run ``model.encoder`` over every batch and stream the embeddings to a CSV file.

    Each batch produced by ``data_loader`` is unpacked as ``(inputs, filenames)``.
    The inputs are moved to ``device``, encoded without gradient tracking, and
    appended to ``outdir/filename`` one batch at a time, so the full set of
    embeddings is never held in memory. The row label for each sample is derived
    from its file name: the text before the first ``".tif"``, then the last
    ``"_"``-separated piece of that text.

    Parameters:
        data_loader: Iterable yielding ``(inputs, filenames)`` batches.
        model: Module exposing an ``encoder`` attribute. It is moved to
            ``device`` and switched to evaluation mode for the duration of the
            call; its previous train/eval mode is restored afterwards.
        outdir (str): Output directory; created if it does not exist.
        filename (str): Name of the CSV file. Defaults to "latent_embeddings.csv".
        device (str, optional): Device to run on. When None, 'cuda' is used if
            available, otherwise 'cpu'.

    Returns:
        None. The embeddings are only written to disk. If the loader yields no
        batches, the file is created empty.

    Example:
        extract3D_embedding(data_loader=my_data_loader, model=my_model, outdir='path/to/output/', device='cuda')
    """
    # Ensure output directory exists
    os.makedirs(outdir, exist_ok=True)
    out_path = os.path.join(outdir, filename)

    # Determine the device to use if not specified
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Move model to the appropriate device
    model.to(device)

    # Inference must run in eval mode, otherwise layers such as Dropout or
    # BatchNorm behave as in training and distort the embeddings.
    was_training = model.training
    model.eval()
    try:
        # Open the file once; newline='' is what pandas expects for a file
        # handle so that no extra blank lines appear on Windows.
        with open(out_path, 'w', newline='') as f, torch.no_grad():
            write_header = True  # header only precedes the first batch
            for inputs, filenames in data_loader:
                # assumes the encoder output is at most 2-D (samples x features),
                # as required by pd.DataFrame — TODO confirm for custom encoders
                embedding = model.encoder(inputs.to(device)).cpu().numpy()

                # Process filenames to use as index
                index = [name.split(".tif")[0].split("_")[-1] for name in filenames]

                pd.DataFrame(embedding, index=index).to_csv(f, header=write_header)
                write_header = False
    finally:
        # Leave the caller's model in the mode it was handed to us in.
        model.train(was_training)

    print(f"Embeddings have been written to {out_path}")

extractEmbedding(ModelPath, dataset_dir, embeddingSavePath, embeddingSaveName='latent_embeddings.csv', input_dim=256, encoding_dim=64, batch_size=8, multi=False, channels=None)

Parameters:

Name Type Description Default
ModelPath str

The file path to the saved trained model (e.g., a .pth file for PyTorch models).

required
dataset_dir str

The directory path containing image data to process. For multi-channel data, this is the parent directory containing subdirectories for each channel.

required
embeddingSavePath str

The directory path where the resultant embeddings CSV file will be saved.

required
embeddingSaveName str

The name of the CSV file to save the extracted embeddings. Defaults to "latent_embeddings.csv".

'latent_embeddings.csv'
input_dim int

The size of the input layer of the autoencoder, generally the product of image dimensions (e.g., width * height for square images). Defaults to 256.

256
encoding_dim int

The size of the encoding layer, representing the dimensionality of the produced embeddings. Defaults to 64.

64
batch_size int

The number of images to process in a single batch during embedding extraction. Defaults to 8.

8
multi bool

A flag indicating whether the dataset is multi-channel (True) or single-channel (False). Defaults to False.

False
channels list

A list of subdirectory names within 'dataset_dir', each corresponding to a channel in multi-channel data. This parameter is required if 'multi' is True. Defaults to None.

None

Returns:

Type Description

A list or array of extracted embeddings, which are also saved to the specified CSV file.

Example:

```python
# For single-channel data

input_dim = 256  #  (16x16)
encoding_dim = 64
batch_size = 32
embeddingSavePath="/home/roh6824/ResearchProject/SpatialMolecular/"
embeddingSaveName="latent_embeddingsxxx.csv"

ModelPath = '/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_DNA_validate_300_model.pth'
dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image/SpatialAE/SinglePatch/DNA1/"
multi = False

extractEmbedding(ModelPath, dataset_dir,
                 embeddingSavePath,
                 embeddingSaveName,
                 input_dim = input_dim,
                 encoding_dim = encoding_dim,
                 batch_size = batch_size)

# For multi-channel data

ModelPath="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_multi_no_DNA_validate_300_model_dim64.pth"
dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/SinglePatch/"
channels = ["DNA1", "CD3", "KERATIN", "CD20", "CD68","CD8A", "CD163","ECAD", "CD31"]
multi = True
extractEmbedding(ModelPath, dataset_dir,
                 embeddingSavePath,
                 embeddingSaveName,
                 input_dim = input_dim,
                 encoding_dim = encoding_dim,
                 batch_size = batch_size,
                 multi = multi,
                 channels = channels)

```
Source code in spatialae/datasets/extract.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def extractEmbedding(ModelPath,
                     dataset_dir,
                     embeddingSavePath,
                     embeddingSaveName = "latent_embeddings.csv",
                     input_dim = 256,
                     encoding_dim = 64,
                     batch_size = 8,
                     multi = False,
                     channels = None
                     ):
    """
    Extract latent embeddings from images using a pre-trained autoencoder.

    Builds a ``LitAutoEncoder``, loads the weights stored at ``ModelPath`` into
    it, wraps the images under ``dataset_dir`` in a dataset and a
    ``DataLoader``, and passes both to ``extract_embedding_update``, which
    writes the embeddings to ``embeddingSavePath/embeddingSaveName``.

    Parameters:
        ModelPath (str):
            The file path to the saved trained model (e.g., a .pth file for PyTorch models).

        dataset_dir (str):
            The directory path containing image data to process.
            For multi-channel data, this is the parent directory containing subdirectories for each channel.

        embeddingSavePath (str):
            The directory path where the resultant embeddings CSV file will be saved.

        embeddingSaveName (str):
            The name of the CSV file to save the extracted embeddings. Defaults to "latent_embeddings.csv".

        input_dim (int, optional):
            The size of the input layer of the autoencoder for a single channel, generally the product
            of image dimensions (e.g., width * height for square images). For multi-channel data it is
            multiplied by the number of channels before the model is built. Defaults to 256.

        encoding_dim (int, optional):
            The size of the encoding layer, representing the dimensionality of the produced embeddings. Defaults to 64.

        batch_size (int, optional):
            The number of images to process in a single batch during embedding extraction. Defaults to 8.

        multi (bool, optional):
            A flag indicating whether the dataset is multi-channel (True) or single-channel (False). Defaults to False.

        channels (list, optional):
            A list of subdirectory names within 'dataset_dir', each corresponding to a channel in multi-channel data.
            This parameter is required (and must be non-empty) if 'multi' is True. Defaults to None.

    Returns:
        The value returned by ``extract_embedding_update``. That helper writes
        the embeddings to the CSV file and has no return statement, so read
        the embeddings back from the saved file.

    Raises:
        ValueError: If 'multi' is True and 'channels' is None or empty.

    Example:

        ```python
        # For single-channel data

        input_dim = 256  #  (16x16)
        encoding_dim = 64
        batch_size = 32
        embeddingSavePath="/home/roh6824/ResearchProject/SpatialMolecular/"
        embeddingSaveName="latent_embeddingsxxx.csv"

        ModelPath = '/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_DNA_validate_300_model.pth'
        dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image/SpatialAE/SinglePatch/DNA1/"
        multi = False

        extractEmbedding(ModelPath, dataset_dir,
                         embeddingSavePath,
                         embeddingSaveName,
                         input_dim = input_dim,
                         encoding_dim = encoding_dim,
                         batch_size = batch_size)

        # For multi-channel data

        ModelPath="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_multi_no_DNA_validate_300_model_dim64.pth"
        dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/SinglePatch/"
        channels = ["DNA1", "CD3", "KERATIN", "CD20", "CD68","CD8A", "CD163","ECAD", "CD31"]
        multi = True
        extractEmbedding(ModelPath, dataset_dir,
                         embeddingSavePath,
                         embeddingSaveName,
                         input_dim = input_dim,
                         encoding_dim = encoding_dim,
                         batch_size = batch_size,
                         multi = multi,
                         channels = channels)

        ```

    """
    if multi:
        # Validate before touching the disk or building the model.
        if channels is None or len(channels) == 0:
            raise ValueError("Should provide marker channels list.")
        used_dataset = MultiChannelImageDataset(dataset_dir, channels)
        # The flattened input holds every channel, so the input layer grows
        # by the number of channels.
        input_dim = input_dim * len(channels)
    else:
        used_dataset = SpatialImageDataset(dataset_dir)

    data_loader = DataLoader(used_dataset, batch_size=batch_size, shuffle=False)

    # Reload the model. map_location="cpu" lets a checkpoint saved on a GPU be
    # loaded on a CPU-only machine; the helper moves the model to its device.
    model = LitAutoEncoder(input_dim, encoding_dim)
    model.load_state_dict(torch.load(ModelPath, map_location="cpu"))

    return extract_embedding_update(data_loader,
                                    input_dim=input_dim,
                                    model=model,
                                    outdir=embeddingSavePath,
                                    filename=embeddingSaveName)

extract_embedding(data_loader, input_dim, model, outdir, filename='latent_embeddings.csv')

Source code in spatialae/datasets/extract.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def extract_embedding(data_loader, input_dim, model, outdir, filename = "latent_embeddings.csv"):
    """
    Encode every batch, save all embeddings to one CSV file, and return them.

    Each batch from ``data_loader`` is unpacked as ``(inputs, filenames)``. The
    inputs are reshaped to ``(-1, input_dim)`` and passed through
    ``model.encoder``. All batch outputs are concatenated along dimension 0 and
    written to ``outdir/filename`` with an index column named "cell". The index
    value for each row is derived from the sample's file name: the text before
    the first ``".tif"``, then the last ``"_"``-separated piece of that text.

    Unlike ``extract_embedding_update``, this keeps every embedding in memory
    until the end and does not move the model or inputs to another device.

    Parameters:
        data_loader: Iterable yielding ``(inputs, filenames)`` batches.
        input_dim (int): Size each sample is flattened to before encoding.
        model: Module exposing an ``encoder`` attribute. It is switched to
            evaluation mode for the duration of the call; its previous
            train/eval mode is restored afterwards.
        outdir (str): Output directory; created if it does not exist.
        filename (str): Name of the CSV file. Defaults to "latent_embeddings.csv".

    Returns:
        torch.Tensor: The concatenated encoder outputs, computed without
        gradient tracking.
    """
    # Guard against outdir == "" (write into the current directory), for
    # which os.makedirs would raise.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    filenames_lst = []
    outputs_list = []

    # Inference must run in eval mode and without autograd: otherwise
    # Dropout/BatchNorm distort the embeddings and every batch keeps its
    # computation graph alive until the function returns.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, filenames in data_loader:
                outputs_list.append(model.encoder(inputs.view(-1, input_dim)))
                filenames_lst.extend(filenames)
    finally:
        model.train(was_training)

    # Stack the tensors
    concatenated_tensors = torch.cat(outputs_list, dim=0)
    print(concatenated_tensors.shape)

    # Plain str.split treats ".tif" literally. The pandas .str.split accessor
    # would interpret this multi-character pattern as a regex, where "."
    # matches any character.
    index = [name.split(".tif")[0].split("_")[-1] for name in filenames_lst]

    # .cpu() keeps the conversion working when the encoder ran on another device.
    df = pd.DataFrame(concatenated_tensors.detach().cpu().numpy(), index=index)
    df.index.name = "cell"
    # os.path.join works whether or not outdir ends with a path separator.
    df.to_csv(os.path.join(outdir, filename))
    return concatenated_tensors

extract_embedding_update(data_loader, input_dim, model, outdir, filename='latent_embeddings.csv', device=None)

Extract embeddings from the model and write them to a CSV file.

Source code in spatialae/datasets/extract.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def extract_embedding_update(data_loader, input_dim, model, outdir, filename="latent_embeddings.csv", device=None):
    """
    Extract embeddings from the model and write them to a CSV file.

    Each batch from ``data_loader`` is unpacked as ``(inputs, filenames)``. The
    inputs are reshaped to ``(-1, input_dim)``, moved to ``device``, and encoded
    with ``model.encoder`` without gradient tracking. Each batch is appended to
    ``outdir/filename`` as soon as it is computed, so the full set of
    embeddings is never held in memory. The row label for each sample is
    derived from its file name: the text before the first ``".tif"``, then the
    last ``"_"``-separated piece of that text.

    Parameters:
        data_loader: Iterable yielding ``(inputs, filenames)`` batches.
        input_dim (int): Size each sample is flattened to before encoding.
        model: Module exposing an ``encoder`` attribute. It is moved to
            ``device`` and switched to evaluation mode for the duration of the
            call; its previous train/eval mode is restored afterwards.
        outdir (str): Output directory; created if it does not exist.
        filename (str): Name of the CSV file. Defaults to "latent_embeddings.csv".
        device (str, optional): Device to run on. When None, 'cuda' is used if
            available, otherwise 'cpu'.

    Returns:
        None. The embeddings are only written to disk. If the loader yields no
        batches, the file is created empty.
    """
    # Ensure output directory exists
    os.makedirs(outdir, exist_ok=True)
    out_path = os.path.join(outdir, filename)

    # Determine the device to use if not specified
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Move model to the appropriate device
    model.to(device)

    # Inference must run in eval mode, otherwise layers such as Dropout or
    # BatchNorm behave as in training and distort the embeddings.
    was_training = model.training
    model.eval()
    try:
        # Open the file once; newline='' is what pandas expects for a file
        # handle so that no extra blank lines appear on Windows.
        with open(out_path, 'w', newline='') as f, torch.no_grad():
            write_header = True  # header only precedes the first batch
            for inputs, filenames in data_loader:
                # Flatten the input and move it to the correct device
                flat_inputs = inputs.view(-1, input_dim).to(device)
                outputs = model.encoder(flat_inputs).cpu().numpy()

                # Process filenames to use as index
                index = [name.split(".tif")[0].split("_")[-1] for name in filenames]

                pd.DataFrame(outputs, index=index).to_csv(f, header=write_header)
                write_header = False
    finally:
        # Leave the caller's model in the mode it was handed to us in.
        model.train(was_training)

    print(f"Embeddings have been written to {out_path}")