Skip to content

visualizeImages

Short Description

Visualizes side-by-side comparisons of original images and their autoencoder reconstructions.

visualizeImages(ModelPath, dataset_dir, searchStrList, input_dim=256, encoding_dim=64, batch_size=8, imageViewSize=16, iterNum=1, multi=False, channels=None)

Parameters:

Name Type Description Default
ModelPath str

Path to the saved trained model file.

required
dataset_dir str

Path to the directory containing the images.

required
searchStrList list

List of strings (cellID) used to match and select groups of images in dataset_dir.

required
input_dim int

The dimensions of the input image represented as one integer, assuming a square image (e.g., for 256x256, input_dim is 256). Default is 256.

256
encoding_dim int

The dimension of the encoding for the autoencoder model. Default is 64.

64
batch_size int

The number of images to process in each batch. Default is 8.

8
imageViewSize int

The size of the images to display in the visualization, assuming square images. Default is 16.

16
iterNum int

The 1-based index of the batch to visualize; preceding batches are skipped and only this one batch is plotted. Default is 1, which visualizes the first batch.

1
multi bool

Indicates whether the visualization is for multiple channels. Default is False. If True, channels must be provided.

False
channels list

List of channel names if multi is True, used to identify the path of each marker's folder. Default is None.

None

Returns:

Type Description

This function does not return any value. It displays the original and reconstructed images with `plt.show()`.

Raises:

Type Description
ValueError

If multi is True but channels is None.

Example

Visualize multiple channels:

import spatialae as sa   
# Define parameters
multi = True
ModelPath="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_multi_no_DNA_validate_300_model_dim64.pth"
dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/SinglePatch/"
channels = ["DNA1", "CD3", "KERATIN", "CD20", "CD68","CD8A", "CD163","ECAD", "CD31"]
batch_size = 4
# Assume `adata` is a preloaded AnnData object (for getting the cells cluster)
adata = sc.read(adataPath)
select_int_list = adata.obs[adata.obs["leiden0.8"] == "8"]["cell"].values.tolist()
select_str_list = [str(i) for i in select_int_list]
# Execute function
sa.visualizeImages(ModelPath,
                dataset_dir,
                searchStrList = select_str_list,
                input_dim = 256,
                encoding_dim = 64,
                batch_size = batch_size,
                imageViewSize = 16,
                iterNum = 1,
                multi = multi,
                channels = channels)

Visualize a single channel (e.g., DNA1):

import spatialae as sa
# Define parameters
multi = False
ModelPath = '/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_DNA_validate_300_model.pth'
dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image/SpatialAE/SinglePatch/DNA1/"
batch_size = 8
# Assume `adata` is a preloaded AnnData object (for getting the cells cluster)
adata = sc.read(adataPath)
select_int_list = adata.obs[adata.obs["leiden0.4"] == "4"]["cell"].values.tolist()
select_str_list = [str(i) for i in select_int_list]
# Execute function
sa.visualizeImages(ModelPath,
                dataset_dir,
                searchStrList = select_str_list,
                input_dim = 256,
                encoding_dim = 64,
                batch_size = batch_size,
                imageViewSize = 16,
                iterNum = 1)

Source code in spatialae/plotting/visualize2d.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def visualizeImages(ModelPath,
                    dataset_dir,
                    searchStrList,
                    input_dim = 256,
                    encoding_dim = 64,
                    batch_size = 8,
                    imageViewSize = 16,
                    iterNum = 1,
                    multi = False,
                    channels = None):
    """
    Visualize side-by-side comparisons of original and reconstructed images.

    A trained ``LitAutoEncoder`` is loaded from ``ModelPath``. The images in
    ``dataset_dir`` matching ``searchStrList`` are batched in order, without
    shuffling. One batch is passed through the model. The originals are plotted
    in the top row, above their reconstructions in the bottom row.

    Parameters:
        ModelPath (str):
            Path to the saved trained model file (a state dict loaded with
            ``torch.load``).

        dataset_dir (str):
            Path to the directory containing the images.

        searchStrList (list):
            List of strings (cellID) used to match and select groups of images
            in `dataset_dir`.

        input_dim (int, optional):
            The dimensions of the input image represented as one integer,
            assuming a square image (e.g., for 256x256, `input_dim` is 256).
            In multi-channel mode the model input size is
            ``input_dim * len(channels)``.
            Default is 256.

        encoding_dim (int, optional):
            The dimension of the encoding for the autoencoder model.
            Default is 64.

        batch_size (int, optional):
            The number of images to process in each batch; also the number of
            columns in the figure.
            Default is 8.

        imageViewSize (int, optional):
            Side length of each displayed (square) image. ``imageViewSize ** 2``
            is expected to equal `input_dim`.
            Default is 16.

        iterNum (int, optional):
            1-based index of the batch to visualize. Preceding batches are
            skipped and only this one batch is plotted.
            Default is 1, which visualizes the first batch.

        multi (bool, optional):
            Indicates whether the visualization is for multiple channels.
            Default is False. If True, `channels` must be provided.

        channels (list, optional):
            List of channel names if `multi` is True, used to identify the path
            of each marker's folder. Ignored when `multi` is False.
            Default is None.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: If `multi` is True but `channels` is None.
        ValueError: If `iterNum` is smaller than 1 or larger than the number of
            available batches.

    Example:
        Visualize multiple channels:
        ```python
        import spatialae as sa
        # Define parameters
        multi = True
        ModelPath="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_multi_no_DNA_validate_300_model_dim64.pth"
        dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/SinglePatch/"
        channels = ["DNA1", "CD3", "KERATIN", "CD20", "CD68","CD8A", "CD163","ECAD", "CD31"]
        batch_size = 4
        # Assumes scanpy is imported as `sc` and `adataPath` points to an AnnData file
        # (used for getting the cells cluster)
        adata = sc.read(adataPath)
        select_int_list = adata.obs[adata.obs["leiden0.8"] == "8"]["cell"].values.tolist()
        select_str_list = [str(i) for i in select_int_list]
        # Execute function
        sa.visualizeImages(ModelPath,
                        dataset_dir,
                        searchStrList = select_str_list,
                        input_dim = 256,
                        encoding_dim = 64,
                        batch_size = batch_size,
                        imageViewSize = 16,
                        iterNum = 1,
                        multi = multi,
                        channels = channels)
        ```

        Visualize a single channel (e.g., DNA1):
        ```python
        import spatialae as sa
        # Define parameters
        multi = False
        ModelPath = '/n/scratch/users/r/roh6824/Results/CRC12image_update/SpatialAE/ln_autoencoder_DNA_validate_300_model.pth'
        dataset_dir="/n/scratch/users/r/roh6824/Results/CRC12image/SpatialAE/SinglePatch/DNA1/"
        batch_size = 8
        # Assumes scanpy is imported as `sc` and `adataPath` points to an AnnData file
        # (used for getting the cells cluster)
        adata = sc.read(adataPath)
        select_int_list = adata.obs[adata.obs["leiden0.4"] == "4"]["cell"].values.tolist()
        select_str_list = [str(i) for i in select_int_list]
        # Execute function
        sa.visualizeImages(ModelPath,
                        dataset_dir,
                        searchStrList = select_str_list,
                        input_dim = 256,
                        encoding_dim = 64,
                        batch_size = batch_size,
                        imageViewSize = 16,
                        iterNum = 1)
        ```
    """
    if multi and channels is None:
        raise ValueError("Should provide marker channels list.")
    # Validate up front: with iterNum < 1 the batch loop would never run.
    if iterNum < 1:
        raise ValueError(f"iterNum must be >= 1, got {iterNum}.")

    n_panels = len(channels) if multi else 1
    # The model consumes one flat vector per sample. In multi-channel mode all
    # channels are concatenated, so that vector is n_panels times longer.
    model_input_dim = input_dim * n_panels
    view_width = imageViewSize * n_panels

    # Reload the trained model.
    model = _load_autoencoder_for_inference(ModelPath, model_input_dim, encoding_dim)

    if multi:
        select_dataset = MultiChannelImageDataset(dataset_dir, channels, get_select=True,
                                                  search_strings=searchStrList)
    else:
        select_dataset = SpatialImageDataset(dataset_dir, get_select=True,
                                             search_strings=searchStrList)
    data_loader = DataLoader(select_dataset, batch_size=batch_size, shuffle=False)

    images = _get_nth_batch_images(data_loader, iterNum)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        flat_images = images.reshape(-1, model_input_dim)
        reconstructed = model(flat_images)

    # NOTE(review): in multi-channel mode each flat vector is reshaped directly
    # into an (imageViewSize, imageViewSize * n_channels) strip. Whether this
    # places the channels side by side depends on the dataset's element layout.
    # If it yields (C, H, W) tensors, a permute may be required. Confirm.
    originals = flat_images.reshape(-1, imageViewSize, view_width)
    reconstructed = reconstructed.reshape(-1, imageViewSize, view_width)

    _plot_original_vs_reconstructed(originals, reconstructed, ncols=batch_size,
                                    figsize=(2 * batch_size * n_panels, 4))


def _load_autoencoder_for_inference(model_path, input_dim, encoding_dim):
    """Build a ``LitAutoEncoder``, load its weights from ``model_path`` and switch it to eval mode."""
    model = LitAutoEncoder(input_dim, encoding_dim)
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines. The plotted tensors are converted with ``.numpy()``, which
    # requires CPU tensors anyway.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # eval() disables training-only behavior (e.g. dropout), if the model has any.
    model.eval()
    return model


def _get_nth_batch_images(data_loader, batch_number):
    """Return the images of the ``batch_number``-th (1-based) batch yielded by ``data_loader``."""
    for index, (images, _) in enumerate(data_loader, start=1):
        if index == batch_number:
            return images
    raise ValueError(
        f"iterNum={batch_number} exceeds the number of available batches "
        f"(fewer than {batch_number} batches were produced)."
    )


def _plot_original_vs_reconstructed(originals, reconstructions, ncols, figsize):
    """Plot originals (top row) above their reconstructions (bottom row) and show the figure."""
    # squeeze=False keeps ``axes`` 2-D even when ncols == 1, so each row is iterable.
    fig, axes = plt.subplots(nrows=2, ncols=ncols, figsize=figsize, squeeze=False)
    for row_images, row_axes in zip((originals, reconstructions), axes):
        for img, ax in zip(row_images, row_axes):
            ax.imshow(img.detach().cpu().numpy(), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        # A final, shorter batch leaves spare columns; hide them instead of
        # showing empty frames.
        for ax in row_axes[len(row_images):]:
            ax.set_visible(False)
    plt.show()