Image Captioning

AutoModelForImageCaptioning

Bases: BaseImageCaptioning

AutoModelForImageCaptioning pipeline supporting different combinations of annotation and validation models: OpenAI, Gemini, and Qwen2-VL.

Example Usage:

from swiftannotate.image import AutoModelForImageCaptioning

# Initialize the pipeline
# Note: You can use any of Qwen2VL, OpenAI, or Gemini for captioning and validation.
captioner = AutoModelForImageCaptioning(
    caption_model="gpt-4o",
    validation_model="gemini-1.5-flash",
    caption_api_key="your_openai_api_key",
    validation_api_key="your_gemini_api_key",
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = captioner.generate(image_paths)

# Print results
# Output: [
#     {
#         'image_path': 'path/to/image1.jpg',
#         'image_caption': 'A cat sitting on a table.',
#         'validation_reasoning': 'The caption is valid.',
#         'validation_score': 0.8
#     },
# ]
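
The pipeline also accepts a local Qwen2-VL model for one or both stages. Below is a minimal sketch, assuming the `Qwen/Qwen2-VL-2B-Instruct` checkpoint and the standard `transformers` loading helpers; a processor is required whenever a local model is used.

```python
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from swiftannotate.image import AutoModelForImageCaptioning

# Load a local Qwen2-VL model and its processor; the processor is mandatory
# whenever a local model is used for captioning or validation.
model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
processor = Qwen2VLProcessor.from_pretrained(model_id)

# Reuse the same local model for both captioning and validation.
captioner = AutoModelForImageCaptioning(
    caption_model=model,
    validation_model=model,
    caption_model_processor=processor,
    validation_model_processor=processor,
    resize_height=280,  # forwarded to the Qwen2-VL image preprocessing
    resize_width=420,
    output_file="captions.json"
)

results = captioner.generate(["path/to/image1.jpg"])
```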

Source code in swiftannotate/image/captioning/auto.py
class AutoModelForImageCaptioning(BaseImageCaptioning):
    """
    AutoModelForImageCaptioning pipeline supporting different combinations of annotation and validation models:
    OpenAI, Gemini, and Qwen2-VL.

    Example Usage:
    ```python
    from swiftannotate.image import AutoModelForImageCaptioning

    # Initialize the pipeline
    # Note: You can use any of Qwen2VL, OpenAI, or Gemini for captioning and validation.
    captioner = AutoModelForImageCaptioning(
        caption_model="gpt-4o",
        validation_model="gemini-1.5-flash",
        caption_api_key="your_openai_api_key",
        validation_api_key="your_gemini_api_key",
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = captioner.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         'image_path': 'path/to/image1.jpg',
    #         'image_caption': 'A cat sitting on a table.',
    #         'validation_reasoning': 'The caption is valid.',
    #         'validation_score': 0.8
    #     },
    # ]
    ```
    """

    SUPPORTED_MODELS = {
        "openai": ["gpt-4o", "gpt-4o-mini"],
        "gemini": ["gemini-1.5-flash", "gemini-1.5-pro", "gemini-2.0-flash-exp", "gemini-1.5-flash-8b"],
        "local": [Qwen2VLForConditionalGeneration]
    }

    def __init__(
        self, 
        caption_model: str | Qwen2VLForConditionalGeneration, 
        validation_model: str | Qwen2VLForConditionalGeneration,
        caption_model_processor: Qwen2VLProcessor | None = None,
        validation_model_processor: Qwen2VLProcessor | None = None,
        caption_api_key: str | None = None, 
        validation_api_key: str | None = None,
        caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
        **kwargs
    ):
        """
        Initialize the AutoModelForImageCaptioning class.
        This class provides functionality for automatic image captioning with optional validation.
        It supports different combinations of annotation and validation models like OpenAI, Gemini, and Qwen2-VL.

        Args:
            caption_model (Union[str, Qwen2VLForConditionalGeneration]): 
                Model or API endpoint for caption generation.
                Can be either a local model instance or API endpoint string.
            validation_model (Union[str, Qwen2VLForConditionalGeneration]): 
                Model or API endpoint for caption validation.
                Can be either a local model instance or API endpoint string.
            caption_model_processor (Optional[Qwen2VLProcessor]): 
                Processor for caption model. 
                Required if using a local model for captioning.
            validation_model_processor (Optional[Qwen2VLProcessor]): 
                Processor for validation model.
                Required if using a local model for validation.
            caption_api_key (Optional[str]): 
                API key for caption service if using API endpoint.
            validation_api_key (Optional[str]): 
                API key for validation service if using API endpoint.
            caption_prompt (str): 
                Prompt template for caption generation.
                Defaults to BASE_IMAGE_CAPTION_PROMPT.
            validation (bool): 
                Whether to perform validation on generated captions.
                Defaults to True.
            validation_prompt (str): 
                Prompt template for caption validation.
                Defaults to BASE_IMAGE_CAPTION_VALIDATION_PROMPT.
            validation_threshold (float): 
                Threshold score for caption validation.
                Defaults to 0.5.
            max_retry (int): 
                Maximum number of retry attempts for failed validation.
                Defaults to 3.
            output_file (Optional[str]): 
                Path to save results.
                If None, results are not saved.
            **kwargs: Additional arguments passed to model initialization.

        Raises:
            ValueError: If required model processors are not provided for local models.
            ValueError: If an unsupported model is provided.

        Note:
            At least one of caption_model_processor or caption_api_key must be provided for caption generation.
            Same applies for validation_model_processor or validation_api_key if validation is enabled.
        """
        super().__init__(
            caption_prompt=caption_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

        self.caption_model, self.caption_model_processor, self.caption_model_type = self._initialize_model(
            caption_model, caption_model_processor, caption_api_key, "caption", **kwargs
        )

        self.validation_model, self.validation_model_processor, self.validation_model_type = self._initialize_model(
            validation_model, validation_model_processor, validation_api_key, "validation", **kwargs
        )

    def _initialize_model(self, model, processor, api_key, stage, **kwargs):
        """Initialize model based on type."""
        if isinstance(model, str):
            if model in self.SUPPORTED_MODELS["openai"]:
                self.detail = kwargs.get("detail", "low")
                self.client = OpenAI(api_key=api_key)
                return model, None, "openai"
            elif model in self.SUPPORTED_MODELS["gemini"]:
                return genai.GenerativeModel(model_name=model), None, "gemini"
            else:
                raise ValueError(f"Unsupported model: {model}")

        elif isinstance(model, Qwen2VLForConditionalGeneration):
            if processor is None:
                raise ValueError(f"Processor is required for Qwen2VL model in {stage} stage")
            self.resize_height = kwargs.get("resize_height", 280)
            self.resize_width = kwargs.get("resize_width", 420)
            return model, processor, "qwen"

        raise ValueError(f"Invalid model type for {stage} stage")

    def _openai_inference(self, messages: List[Dict], stage: str = "annotate", **kwargs):
        """Inference for OpenAI model."""

        model = self.caption_model if stage == "annotate" else self.validation_model
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs
            )
            output = response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"OpenAI inference failed: {e}")
            output = "ERROR"

        if stage == "annotate":
            return output

        # Validation stage: the prompt asks for a JSON object with
        # 'validation_reasoning' and 'confidence' keys.
        try:
            validation_output = json.loads(output.replace('```', '').replace('json', ''))
            return validation_output["validation_reasoning"], validation_output["confidence"]
        except Exception as e:
            logger.error(f"Image caption validation parsing failed: {e}")
            return output, 0.0

    def _gemini_inference(self, messages: List[str], stage: str, **kwargs):
        """Inference for Gemini model."""

        if stage == "annotate":
            try:
                response = self.caption_model.generate_content(
                    messages,
                    generation_config=genai.GenerationConfig(
                        **kwargs
                    )
                )
                image_caption = response.text.strip()
            except Exception as e:
                logger.error(f"Image captioning failed: {e}")
                image_caption = "ERROR"

            return image_caption
        else:
            try:
                response = self.validation_model.generate_content(
                    messages,
                    generation_config=genai.GenerationConfig(
                        response_mime_type="application/json", 
                        response_schema=ImageValidationOutputGemini
                    )
                )
                # The structured output is returned as a JSON string on response.text
                validation_output = json.loads(response.text)
                validation_reasoning = validation_output["validation_reasoning"]
                confidence = validation_output["confidence"]
            except Exception as e:
                logger.error(f"Image caption validation failed: {e}")
                validation_reasoning = "ERROR"
                confidence = 0.0

            return validation_reasoning, confidence

    def _qwen_inference(self, model: Qwen2VLForConditionalGeneration, processor: Qwen2VLProcessor, messages: List[Dict], stage: str, **kwargs):
        """Inference for Qwen2VL model."""

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        # Inference: Generation of the output
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 512

        generated_ids = model.generate(**inputs, **kwargs)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        if stage == "annotate":
            image_caption = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

            return image_caption
        else:
            validation_output = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

            # TODO: Need a better way to parse the output
            try:
                validation_output = validation_output.replace('```', '').replace('json', '')
                validation_output = json.loads(validation_output)
                validation_reasoning = validation_output["validation_reasoning"]
                confidence = validation_output["confidence"]
            except Exception as e:
                logger.error(f"Image caption validation parsing failed trying to parse using another logic.")

                number_str  = ''.join((ch if ch in '0123456789.-e' else ' ') for ch in validation_output)
                number_str = [i for i in number_str.split() if i.isalnum()]
                potential_confidence_scores = [float(i) for i in number_str if float(i) >= 0 and float(i) <= 1]
                confidence = max(potential_confidence_scores) if potential_confidence_scores else 0.0
                validation_reasoning = validation_output

            return validation_reasoning, confidence

    def annotate(self, image: str, feedback_prompt: str = "", **kwargs) -> str:
        """
        Annotates the image with a caption. Implements the logic to generate captions for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the caption does not pass validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated caption for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the caption you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Try to generate a better caption for the image.
            """
        else:
            user_prompt = "Describe the given image."

        if self.caption_model_type == "openai":
            messages=[
                {"role": "system", "content": self.caption_prompt},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image}",
                                "detail": self.detail
                            },
                        },
                        {"type": "text", "text": user_prompt},
                    ]
                }
            ]

            caption = self._openai_inference(messages, "annotate", **kwargs)

        elif self.caption_model_type == "gemini":
            messages = [
                self.caption_prompt,
                {"mime_type": "image/jpeg", "data": image},
                user_prompt,
            ]

            caption = self._gemini_inference(messages, "annotate", **kwargs)

        else:
            messages = [
                {"role": "system", "content": self.caption_prompt},
                {
                    "role": "user", 
                    "content": [
                        {
                            "type": "image", 
                            "image": f"data:image;base64,{image}",
                            "resized_height": self.resize_height,
                            "resized_width": self.resize_width,
                        },
                        {"type": "text", "text": user_prompt},
                    ],
                },
            ]

            caption = self._qwen_inference(self.caption_model, self.caption_model_processor , messages, "annotate", **kwargs)

        return caption


    def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]:
        """
        Validates the caption generated for the image.

        Args:
            image (str): Base64 encoded image.
            caption (str): Caption generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the caption.
        """
        if caption == "ERROR":
            return "ERROR", 0

        if self.validation_model_type == "openai":
            messages = [
                {
                    "role": "system",
                    "content": self.validation_prompt
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image}",
                                "detail": self.detail
                            },
                        },
                        {
                            "type": "text",
                            "text": caption + "\nValidate the caption generated for the given image. "
                            "Return the output as a JSON object with keys 'validation_reasoning' and 'confidence'."
                        }
                    ]
                }
            ]  
            validation_reasoning, confidence = self._openai_inference(messages, "validate", **kwargs)  
        elif self.validation_model_type == "gemini":
            messages = [
                self.validation_prompt,
                {'mime_type':'image/jpeg', 'data': image},
                caption,
                "Validate the caption generated for the given image."
            ]
            validation_reasoning, confidence = self._gemini_inference(messages, "validate", **kwargs)
        else:
            messages = [
                {"role": "system", "content": self.validation_prompt},
                {
                    "role": "user", 
                    "content": [
                        {
                            "type": "image", 
                            "image": f"data:image;base64,{image}",
                            "resized_height": self.resize_height,
                            "resized_width": self.resize_width,
                        },
                        {"type": "text", "text": caption},
                        {
                            "type": "text", 
                            "text": """
                            Validate the caption generated for the given image. 
                            Return output as a JSON object with keys as 'validation_reasoning' and 'confidence'.
                            """
                        },
                    ],
                },
            ]
            validation_reasoning, confidence = self._qwen_inference(self.validation_model, self.validation_model_processor, messages, "validate", **kwargs)

        return validation_reasoning, confidence

    def generate(self, image_paths, **kwargs):
        """
        Generates captions for a list of images. Implements the logic to generate captions for a list of images.

        Args:
            image_paths (List[str]): List of image paths to generate captions for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of captions, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(caption_model, validation_model, caption_model_processor=None, validation_model_processor=None, caption_api_key=None, validation_api_key=None, caption_prompt=BASE_IMAGE_CAPTION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CAPTION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None, **kwargs)

Initialize the AutoModelForImageCaptioning class. This class provides functionality for automatic image captioning with optional validation. It supports different combinations of annotation and validation models like OpenAI, Gemini, and Qwen2-VL.

Parameters:

Name Type Description Default
caption_model Union[str, Qwen2VLForConditionalGeneration]

Model or API endpoint for caption generation. Can be either a local model instance or API endpoint string.

required
validation_model Union[str, Qwen2VLForConditionalGeneration]

Model or API endpoint for caption validation. Can be either a local model instance or API endpoint string.

required
caption_model_processor Optional[Qwen2VLProcessor]

Processor for caption model. Required if using a local model for captioning.

None
validation_model_processor Optional[Qwen2VLProcessor]

Processor for validation model. Required if using a local model for validation.

None
caption_api_key Optional[str]

API key for caption service if using API endpoint.

None
validation_api_key Optional[str]

API key for validation service if using API endpoint.

None
caption_prompt str

Prompt template for caption generation. Defaults to BASE_IMAGE_CAPTION_PROMPT.

BASE_IMAGE_CAPTION_PROMPT
validation bool

Whether to perform validation on generated captions. Defaults to True.

True
validation_prompt str

Prompt template for caption validation. Defaults to BASE_IMAGE_CAPTION_VALIDATION_PROMPT.

BASE_IMAGE_CAPTION_VALIDATION_PROMPT
validation_threshold float

Threshold score for caption validation. Defaults to 0.5.

0.5
max_retry int

Maximum number of retry attempts for failed validation. Defaults to 3.

3
output_file Optional[str]

Path to save results. If None, results are not saved.

None
**kwargs

Additional arguments passed to model initialization.

{}

Raises:

Type Description
ValueError

If required model processors are not provided for local models.

ValueError

If an unsupported model is provided.

Note

At least one of caption_model_processor or caption_api_key must be provided for caption generation. Same applies for validation_model_processor or validation_api_key if validation is enabled.
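
When an OpenAI model is used for either stage, the image `detail` level is read from `**kwargs` and defaults to `"low"`. A small sketch of overriding it (the model names and key are placeholders):

```python
# Sketch: `detail` is picked up from **kwargs for OpenAI models and attached
# to the image_url payload; it defaults to "low" if not provided.
captioner = AutoModelForImageCaptioning(
    caption_model="gpt-4o-mini",
    validation_model="gpt-4o-mini",
    caption_api_key="your_openai_api_key",
    validation_api_key="your_openai_api_key",
    detail="high",
    output_file="captions.json"
)
```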

Source code in swiftannotate/image/captioning/auto.py
def __init__(
    self, 
    caption_model: str | Qwen2VLForConditionalGeneration, 
    validation_model: str | Qwen2VLForConditionalGeneration,
    caption_model_processor: Qwen2VLProcessor | None = None,
    validation_model_processor: Qwen2VLProcessor | None = None,
    caption_api_key: str | None = None, 
    validation_api_key: str | None = None,
    caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
    **kwargs
):
    """
    Initialize the AutoModelForImageCaptioning class.
    This class provides functionality for automatic image captioning with optional validation.
    It supports different combinations of annotation and validation models like OpenAI, Gemini, and Qwen2-VL.

    Args:
        caption_model (Union[str, Qwen2VLForConditionalGeneration]): 
            Model or API endpoint for caption generation.
            Can be either a local model instance or API endpoint string.
        validation_model (Union[str, Qwen2VLForConditionalGeneration]): 
            Model or API endpoint for caption validation.
            Can be either a local model instance or API endpoint string.
        caption_model_processor (Optional[Qwen2VLProcessor]): 
            Processor for caption model. 
            Required if using a local model for captioning.
        validation_model_processor (Optional[Qwen2VLProcessor]): 
            Processor for validation model.
            Required if using a local model for validation.
        caption_api_key (Optional[str]): 
            API key for caption service if using API endpoint.
        validation_api_key (Optional[str]): 
            API key for validation service if using API endpoint.
        caption_prompt (str): 
            Prompt template for caption generation.
            Defaults to BASE_IMAGE_CAPTION_PROMPT.
        validation (bool): 
            Whether to perform validation on generated captions.
            Defaults to True.
        validation_prompt (str): 
            Prompt template for caption validation.
            Defaults to BASE_IMAGE_CAPTION_VALIDATION_PROMPT.
        validation_threshold (float): 
            Threshold score for caption validation.
            Defaults to 0.5.
        max_retry (int): 
            Maximum number of retry attempts for failed validation.
            Defaults to 3.
        output_file (Optional[str]): 
            Path to save results.
            If None, results are not saved.
        **kwargs: Additional arguments passed to model initialization.

    Raises:
        ValueError: If required model processors are not provided for local models.
        ValueError: If an unsupported model is provided.

    Note:
        At least one of caption_model_processor or caption_api_key must be provided for caption generation.
        Same applies for validation_model_processor or validation_api_key if validation is enabled.
    """
    super().__init__(
        caption_prompt=caption_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

    self.caption_model, self.caption_model_processor, self.caption_model_type = self._initialize_model(
        caption_model, caption_model_processor, caption_api_key, "caption", **kwargs
    )

    self.validation_model, self.validation_model_processor, self.validation_model_type = self._initialize_model(
        validation_model, validation_model_processor, validation_api_key, "validation", **kwargs
    )

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a caption. Implements the logic to generate captions for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the caption does not pass validation threshold.

Parameters:

Name Type Description Default
image str

Base64 encoded image.

required
feedback_prompt str

Feedback prompt for the user to generate a better caption. Defaults to ''.

''
**kwargs

Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

{}

Returns:

Name Type Description
str str

Generated caption for the image.
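
`annotate` works on a base64-encoded image string rather than a file path. A minimal sketch of calling it directly, assuming the `captioner` from the example above; the feedback text is a made-up illustration:

```python
import base64

# annotate() and validate() expect the image as a base64-encoded string.
with open("path/to/image1.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# First attempt with no feedback.
caption = captioner.annotate(image_b64, feedback_prompt="")

# Retry, folding the previous validation reasoning into the prompt.
caption = captioner.annotate(image_b64, feedback_prompt="The caption missed the dog in the background.")
```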

Source code in swiftannotate/image/captioning/auto.py
def annotate(self, image: str, feedback_prompt: str = "", **kwargs) -> str:
    """
    Annotates the image with a caption. Implements the logic to generate captions for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the caption does not pass validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated caption for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the caption you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Try to generate a better caption for the image.
        """
    else:
        user_prompt = "Describe the given image."

    if self.caption_model_type == "openai":
        messages=[
            {"role": "system", "content": self.caption_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}",
                            "detail": self.detail
                        },
                    },
                    {"type": "text", "text": user_prompt},
                ]
            }
        ]

        caption = self._openai_inference(messages, "annotate", **kwargs)

    elif self.caption_model_type == "gemini":
        messages = [
            self.caption_prompt,
            {"mime_type": "image/jpeg", "data": image},
            user_prompt,
        ]

        caption = self._gemini_inference(messages, "annotate", **kwargs)

    else:
        messages = [
            {"role": "system", "content": self.caption_prompt},
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image", 
                        "image": f"data:image;base64,{image}",
                        "resized_height": self.resize_height,
                        "resized_width": self.resize_width,
                    },
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

        caption = self._qwen_inference(self.caption_model, self.caption_model_processor , messages, "annotate", **kwargs)

    return caption

generate(image_paths, **kwargs)

Generates captions for a list of images. Implements the logic to generate captions for a list of images.

Parameters:

Name Type Description Default
image_paths List[str]

List of image paths to generate captions for.

required
**kwargs

Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

{}

Returns:

Type Description

List[Dict]: List of captions, validation reasoning and confidence scores for each image.
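
The extra keyword arguments are forwarded to the underlying backend, so the accepted names depend on the chosen models. A hedged sketch, assuming the `captioner` from the example above:

```python
# Sketch: kwargs are passed straight to the backend call, so use parameter
# names valid for that backend (e.g. temperature for OpenAI/Gemini,
# max_new_tokens for a local Qwen2-VL model).
results = captioner.generate(
    ["path/to/image1.jpg", "path/to/image2.jpg"],
    temperature=0.2
)
```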

Source code in swiftannotate/image/captioning/auto.py
def generate(self, image_paths, **kwargs):
    """
    Generates captions for a list of images. Implements the logic to generate captions for a list of images.

    Args:
        image_paths (List[str]): List of image paths to generate captions for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of captions, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results

validate(image, caption, **kwargs)

Validates the caption generated for the image.

Parameters:

Name Type Description Default
image str

Base64 encoded image.

required
caption str

Caption generated for the image.

required

Returns:

Type Description
Tuple[str, float]

Tuple[str, float]: Validation reasoning and confidence score for the caption.
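
A short sketch of calling `validate` directly and comparing the score against the pipeline's threshold (0.5 by default), assuming the `captioner` and `image_b64` from the sketches above:

```python
# Sketch: validate() returns the model's reasoning plus a confidence score.
reasoning, score = captioner.validate(image_b64, "A cat sitting on a table.")
if score >= 0.5:  # compare against your validation_threshold
    print("Caption accepted:", reasoning)
```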

Source code in swiftannotate/image/captioning/auto.py
def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]:
    """
    Validates the caption generated for the image.

    Args:
        image (str): Base64 encoded image.
        caption (str): Caption generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the caption.
    """
    if caption == "ERROR":
        return "ERROR", 0

    if self.validation_model_type == "openai":
        messages = [
            {
                "role": "system",
                "content": self.validation_prompt
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}",
                            "detail": self.detail
                        },
                    },
                    {
                        "type": "text",
                        "text": caption + "\nValidate the caption generated for the given image. "
                        "Return the output as a JSON object with keys 'validation_reasoning' and 'confidence'."
                    }
                ]
            }
        ]  
        validation_reasoning, confidence = self._openai_inference(messages, "validate", **kwargs)  
    elif self.validation_model_type == "gemini":
        messages = [
            self.validation_prompt,
            {'mime_type':'image/jpeg', 'data': image},
            caption,
            "Validate the caption generated for the given image."
        ]
        validation_reasoning, confidence = self._gemini_inference(messages, "validate", **kwargs)
    else:
        messages = [
            {"role": "system", "content": self.validation_prompt},
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image", 
                        "image": f"data:image;base64,{image}",
                        "resized_height": self.resize_height,
                        "resized_width": self.resize_width,
                    },
                    {"type": "text", "text": caption},
                    {
                        "type": "text", 
                        "text": """
                        Validate the caption generated for the given image. 
                        Return output as a JSON object with keys as 'validation_reasoning' and 'confidence'.
                        """
                    },
                ],
            },
        ]
        validation_reasoning, confidence = self._qwen_inference(self.validation_model, self.validation_model_processor, messages, "validate", **kwargs)

    return validation_reasoning, confidence

GeminiForImageCaptioning

Bases: BaseImageCaptioning

GeminiForImageCaptioning pipeline for generating captions for images using Gemini models.

Example usage:

from swiftannotate.image import GeminiForImageCaptioning

# Initialize the pipeline
captioner = GeminiForImageCaptioning(
    caption_model="gemini-1.5-pro",
    validation_model="gemini-1.5-flash",
    api_key="your_api_key_here",
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = captioner.generate(image_paths)

# Print results
# Output: [
#     {
#         'image_path': 'path/to/image1.jpg', 
#         'image_caption': 'A cat sitting on a table.', 
#         'validation_reasoning': 'The caption is valid.', 
#         'validation_score': 0.8
#     }, 
# ]
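
For the captioning call, extra keyword arguments are wrapped into `genai.GenerationConfig`, so any of its fields can be passed. A hedged sketch, assuming the `captioner` from the example above:

```python
# Sketch: kwargs are forwarded to genai.GenerationConfig for caption generation
# (validation uses a fixed JSON-schema config), so GenerationConfig fields such
# as temperature or max_output_tokens can be tuned here.
results = captioner.generate(
    ["path/to/image1.jpg"],
    temperature=0.2,
    max_output_tokens=256
)
```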

Source code in swiftannotate/image/captioning/gemini.py
class GeminiForImageCaptioning(BaseImageCaptioning):
    """
    GeminiForImageCaptioning pipeline for generating captions for images using Gemini models.

    Example usage:
    ```python
    from swiftannotate.image import GeminiForImageCaptioning

    # Initialize the pipeline
    captioner = GeminiForImageCaptioning(
        caption_model="gemini-1.5-pro",
        validation_model="gemini-1.5-flash",
        api_key="your_api_key_here",
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = captioner.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         'image_path': 'path/to/image1.jpg', 
    #         'image_caption': 'A cat sitting on a table.', 
    #         'validation_reasoning': 'The caption is valid.', 
    #         'validation_score': 0.8
    #     }, 
    # ]
    ```
    """

    def __init__(
        self, 
        caption_model: str, 
        validation_model: str,
        api_key: str, 
        caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
    ):
        """
        Initializes the GeminiForImageCaptioning pipeline.

        Args:
            caption_model (str): 
                Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
            validation_model (str): 
                Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
            api_key (str): 
                Google Gemini API key.
            caption_prompt (str | None, optional): 
                System prompt for captioning images.
                Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
            validation (bool, optional): 
                Use validation step or not. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image captions; it should specify the range of the validation score to be generated. 
                Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold to determine whether an image caption is valid; it should be within the range specified for the validation score. 
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image caption. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Notes:
            `validation_prompt` should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1).
            Your `validation_threshold` should be within this specified range.
        """        
        genai.configure(api_key=api_key)
        self.caption_model = genai.GenerativeModel(model_name=caption_model)
        self.validation_model = genai.GenerativeModel(model_name=validation_model)

        super().__init__(
            caption_prompt=caption_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
        """
        Annotates the image with a caption. Implements the logic to generate captions for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the caption does not pass validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated caption for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the caption you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Try to generate a better caption for the image.
            """
        else:
            user_prompt = "Describe the given image."

        messages = [
            self.caption_prompt,
            {'mime_type':'image/jpeg', 'data': image}, 
            user_prompt
        ]

        try:
            response = self.caption_model.generate_content(
                messages,
                generation_config=genai.GenerationConfig(
                    **kwargs
                )
            )
            image_caption = response.text.strip()
        except Exception as e:
            logger.error(f"Image captioning failed: {e}")
            image_caption = "ERROR"

        return image_caption

    def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]:
        """
        Validates the caption generated for the image.

        Args:
            image (str): Base64 encoded image.
            caption (str): Caption generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the caption.
        """
        if caption == "ERROR":
            return "ERROR", 0.0

        messages = [
            self.validation_prompt,
            {'mime_type':'image/jpeg', 'data': image},
            caption,
            "Validate the caption generated for the given image."
        ]

        try:
            response = self.validation_model.generate_content(
                messages,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json", 
                    response_schema=ImageValidationOutputGemini
                )
            )
            # Parse the structured JSON output (assumes `json` is imported at module level).
            validation_output = json.loads(response.text)
            validation_reasoning = validation_output["validation_reasoning"]
            confidence = validation_output["confidence"]
        except Exception as e:
            logger.error(f"Image caption validation failed: {e}")
            validation_reasoning = "ERROR"
            confidence = 0.0

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates captions for a list of images. Implements the logic to generate captions for a list of images.

        Args:
            image_paths (List[str]): List of image paths to generate captions for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of captions, validation reasoning and confidence scores for each image.
        """

        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results 

__init__(caption_model, validation_model, api_key, caption_prompt=BASE_IMAGE_CAPTION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CAPTION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None)

Initializes the GeminiForImageCaptioning pipeline.

Parameters:

Name Type Description Default
caption_model str

Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.

required
validation_model str

Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.

required
api_key str

Google Gemini API key.

required
caption_prompt str | None

System prompt for captioning images. Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.

BASE_IMAGE_CAPTION_PROMPT
validation bool

Use validation step or not. Defaults to True.

True
validation_prompt str | None

System prompt for validating image captions; it should specify the range of the validation score to be generated. Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.

BASE_IMAGE_CAPTION_VALIDATION_PROMPT
validation_threshold float

Threshold to determine whether an image caption is valid; it should be within the range specified for the validation score. Defaults to 0.5.

0.5
max_retry int

Number of retries before giving up on the image caption. Defaults to 3.

3
output_file str | None

Output file path, only JSON is supported for now. Defaults to None.

None
Notes

validation_prompt should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1). Your validation_threshold should be within this specified range.
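
A minimal sketch of pairing a custom validation prompt with a matching threshold, assuming a 0-1 score range; the prompt text is illustrative only:

```python
# Sketch: the validation prompt states the scoring range explicitly, and the
# threshold is chosen inside that same range.
CUSTOM_VALIDATION_PROMPT = (
    "You are validating image captions. Explain whether the caption matches the "
    "image and return a confidence score between 0 and 1."
)

captioner = GeminiForImageCaptioning(
    caption_model="gemini-1.5-pro",
    validation_model="gemini-1.5-flash",
    api_key="your_api_key_here",
    validation_prompt=CUSTOM_VALIDATION_PROMPT,
    validation_threshold=0.7,  # within the 0-1 range stated in the prompt
    output_file="captions.json"
)
```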

Source code in swiftannotate/image/captioning/gemini.py
def __init__(
    self, 
    caption_model: str, 
    validation_model: str,
    api_key: str, 
    caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
):
    """
    Initializes the GeminiForImageCaptioning pipeline.

    Args:
        caption_model (str): 
            Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
        validation_model (str): 
            Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
        api_key (str): 
            Google Gemini API key.
        caption_prompt (str | None, optional): 
            System prompt for captioning images.
            Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
        validation (bool, optional): 
            Use validation step or not. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image captions; it should specify the range of the validation score to be generated. 
            Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold to determine whether an image caption is valid; it should be within the range specified for the validation score. 
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image caption. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Notes:
        `validation_prompt` should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1).
        Your `validation_threshold` should be within this specified range.
    """        
    genai.configure(api_key=api_key)
    self.caption_model = genai.GenerativeModel(model_name=caption_model)
    self.validation_model = genai.GenerativeModel(model_name=validation_model)

    super().__init__(
        caption_prompt=caption_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a caption. Implements the logic to generate captions for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the caption does not pass validation threshold.

Parameters:

Name Type Description Default
image str

Base64 encoded image.

required
feedback_prompt str

Feedback prompt for the user to generate a better caption. Defaults to ''.

''
**kwargs

Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

{}

Returns:

Name Type Description
str str

Generated caption for the image.

Source code in swiftannotate/image/captioning/gemini.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
    """
    Annotates the image with a caption. Implements the logic to generate captions for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the caption does not pass validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated caption for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the caption you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Try to generate a better caption for the image.
        """
    else:
        user_prompt = "Describe the given image."

    messages = [
        self.caption_prompt,
        {'mime_type':'image/jpeg', 'data': image}, 
        user_prompt
    ]

    try:
        response = self.caption_model.generate_content(
            messages,
            generation_config=genai.GenerationConfig(
                **kwargs
            )
        )
        image_caption = response.text.strip()
    except Exception as e:
        logger.error(f"Image captioning failed: {e}")
        image_caption = "ERROR"

    return image_caption

generate(image_paths, **kwargs)

Generates captions for a list of images. Implements the logic to generate captions for a list of images.

Parameters:

Name Type Description Default
image_paths List[str]

List of image paths to generate captions for.

required
**kwargs

Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

{}

Returns:

Type Description
List[Dict]

List[Dict]: List of captions, validation reasoning and confidence scores for each image.

Source code in swiftannotate/image/captioning/gemini.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates captions for a list of images. Implements the logic to generate captions for a list of images.

    Args:
        image_paths (List[str]): List of image paths to generate captions for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of captions, validation reasoning and confidence scores for each image.
    """

    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results 

validate(image, caption, **kwargs)

Validates the caption generated for the image.

Parameters:

Name Type Description Default
image str

Base64 encoded image.

required
caption str

Caption generated for the image.

required

Returns:

Type Description
Tuple[str, float]

Tuple[str, float]: Validation reasoning and confidence score for the caption.

Source code in swiftannotate/image/captioning/gemini.py
def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]:
    """
    Validates the caption generated for the image.

    Args:
        image (str): Base64 encoded image.
        caption (str): Caption generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the caption.
    """
    if caption == "ERROR":
        return "ERROR", 0.0

    messages = [
        self.validation_prompt,
        {'mime_type':'image/jpeg', 'data': image},
        caption,
        "Validate the caption generated for the given image."
    ]

    try:
        response = self.validation_model.generate_content(
            messages,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json", 
                response_schema=ImageValidationOutputGemini
            )
        )
        # Parse the structured JSON output (assumes `json` is imported at module level).
        validation_output = json.loads(response.text)
        validation_reasoning = validation_output["validation_reasoning"]
        confidence = validation_output["confidence"]
    except Exception as e:
        logger.error(f"Image caption validation failed: {e}")
        validation_reasoning = "ERROR"
        confidence = 0.0

    return validation_reasoning, confidence

OllamaForImageCaptioning

Bases: BaseImageCaptioning

OllamaForImageCaptioning pipeline using Ollama API.

Example usage:

from swiftannotate.image import OllamaForImageCaptioning

# Initialize the pipeline
captioner = OllamaForImageCaptioning(
    caption_model="llama3.2-vision",
    validation_model="llama3.2-vision",
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = captioner.generate(image_paths)

# Print results
# Output: [
#     {
#         'image_path': 'path/to/image1.jpg', 
#         'image_caption': 'A cat sitting on a table.', 
#         'validation_reasoning': 'The caption is valid.', 
#         'validation_score': 0.8
#     }, 
# ]
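
Generation arguments are forwarded as Ollama `options`, and this pipeline defaults `temperature` to 0.0. A hedged sketch, assuming the `captioner` from the example above:

```python
# Sketch: extra kwargs become Ollama `options`; temperature defaults to 0.0
# in this pipeline unless overridden, and num_predict caps generated tokens.
results = captioner.generate(
    ["path/to/image1.jpg"],
    temperature=0.3,
    num_predict=256
)
```
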
Source code in swiftannotate/image/captioning/ollama.py
class OllamaForImageCaptioning(BaseImageCaptioning):
    """
    OllamaForImageCaptioning pipeline using Ollama API.

    Example usage:

    ```python
    from swiftannotate.image import OllamaForImageCaptioning

    # Initialize the pipeline
    captioner = OllamaForImageCaptioning(
        caption_model="llama3.2-vision",
        validation_model="llama3.2-vision",
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = captioner.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         'image_path': 'path/to/image1.jpg', 
    #         'image_caption': 'A cat sitting on a table.', 
    #         'validation_reasoning': 'The caption is valid.', 
    #         'validation_score': 0.8
    #     }, 
    # ]
    ```
    """

    def __init__(
        self, 
        caption_model: str, 
        validation_model: str,
        caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
    ):
        """
        Initializes the OllamaForImageCaptioning pipeline.

        Args:
            caption_model (str): 
                Can be any of the multimodal (vision) models supported by Ollama,
                including specific model versions.
            validation_model (str): 
                Can be any of the multimodal (vision) models supported by Ollama,
                including specific model versions.
            caption_prompt (str | None, optional): 
                System prompt for captioning images.
                Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
            validation (bool, optional): 
                Use validation step or not. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image captions; it should specify the range of the validation score to be generated. 
                Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold to determine whether an image caption is valid; it should be within the range specified for the validation score. 
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image caption. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Notes:
            `validation_prompt` should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1).
            Your `validation_threshold` should be within this specified range.
        """

        if not self._validate_ollama_model(caption_model):
            raise ValueError(f"Caption model {caption_model} is not supported by Ollama.")

        if not self._validate_ollama_model(validation_model):
            raise ValueError(f"Validation model {validation_model} is not supported by Ollama.")

        self.caption_model = caption_model
        self.validation_model = validation_model

        super().__init__(
            caption_prompt=caption_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

    def _validate_ollama_model(self, model: str) -> bool:
        try:
            ollama.chat(model)
        except ollama.ResponseError as e:
            logger.error(f"Error: {e.error}")
            if e.status_code == 404:
                try:
                    ollama.pull(model)
                    logger.info(f"Model {model} is now downloaded.")
                except ollama.ResponseError as e:
                    logger.error(f"Error: {e.error}")
                    logger.error(f"Model {model} could not be downloaded. Check the model name and try again.")
                    return False
            logger.info(f"Model {model} is now downloaded.")

        return True


    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
        """
        Annotates the image with a caption. Implements the logic to generate captions for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the caption does not pass validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated caption for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the caption you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Try to generate a better caption for the image.
            """
        else:
            user_prompt = "Describe the given image."

        messages=[
            {"role": "system", "content": self.caption_prompt},
            {
                "role": "user",
                "images": [image],
                "content": user_prompt
            }
        ]

        if not "temperature" in kwargs:
            kwargs["temperature"] = 0.0

        try:

            response = ollama.chat(
                model=self.caption_model,
                messages=messages,
                options=kwargs
            )
            image_caption = response.message.content

        except Exception as e:
            logger.error(f"Image captioning failed: {e}")
            image_caption = "ERROR"

        return image_caption

    def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]: 
        """
        Validates the caption generated for the image.

        Args:
            image (str): Base64 encoded image.
            caption (str): Caption generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the caption.
        """
        if caption == "ERROR":
            return "ERROR", 0

        messages = [
            {
                "role": "system",
                "content": self.validation_prompt
            },
            {
                "role": "user",
                "images": [image],
                "content": caption + "\nValidate the caption generated for the given image."
            }
        ]      

        if not "temperature" in kwargs:
            kwargs["temperature"] = 0.0

        try:

            response = ollama.chat(
                model=self.validation_model,
                messages=messages,
                format=ImageValidationOutputOllama.model_json_schema(),
                options=kwargs
            )

            validation_output = ImageValidationOutputOllama.model_validate_json(response.message.content)

            validation_reasoning = validation_output.validation_reasoning
            confidence = validation_output.confidence

        except Exception as e:
            logger.error(f"Image caption validation failed: {e}")
            validation_reasoning = "ERROR"
            confidence = 0

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates captions for a list of images. Implements the logic to generate captions for a list of images.

        Args:
            image_paths (List[str]): List of image paths to generate captions for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of captions, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(caption_model, validation_model, caption_prompt=BASE_IMAGE_CAPTION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CAPTION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None)

Initializes the OllamaForImageCaptioning pipeline.

Parameters:

    caption_model (str, required):
        Any of the multimodal (vision) models supported by Ollama, including specific model versions.
    validation_model (str, required):
        Any of the multimodal (vision) models supported by Ollama, including specific model versions.
    caption_prompt (str | None, default BASE_IMAGE_CAPTION_PROMPT):
        System prompt for captioning images. Uses the default BASE_IMAGE_CAPTION_PROMPT if not provided.
    validation (bool, default True):
        Whether to run the validation step.
    validation_prompt (str | None, default BASE_IMAGE_CAPTION_VALIDATION_PROMPT):
        System prompt for validating image captions; it should specify the range of the validation score to be generated. Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
    validation_threshold (float, default 0.5):
        Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
    max_retry (int, default 3):
        Number of retries before giving up on the image caption.
    output_file (str | None, default None):
        Output file path; only JSON is supported for now.

Notes

validation_prompt should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1). Your validation_threshold should fall within this specified range.
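
For instance, a custom validation prompt that asks for a score in a known range pairs with a threshold inside that range. The sketch below assumes OllamaForImageCaptioning is exported from swiftannotate.image like the other pipelines on this page; the model name and prompt text are only illustrative:

```python
from swiftannotate.image import OllamaForImageCaptioning

# Illustrative validation prompt: it states the scoring range (0-1),
# so a validation_threshold of 0.7 falls inside that range.
custom_validation_prompt = (
    "You are given an image and a candidate caption. Explain whether the caption "
    "is accurate and return a confidence score between 0 and 1."
)

captioner = OllamaForImageCaptioning(
    caption_model="llama3.2-vision",      # any vision model available to your Ollama install
    validation_model="llama3.2-vision",
    validation_prompt=custom_validation_prompt,
    validation_threshold=0.7,
    output_file="captions.json",
)
```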

Source code in swiftannotate/image/captioning/ollama.py
def __init__(
    self, 
    caption_model: str, 
    validation_model: str,
    caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
):
    """
    Initializes the OllamaForImageCaptioning pipeline.

    Args:
        caption_model (str): 
            Any of the multimodal (vision) models supported by Ollama,
            including specific model versions.
        validation_model (str): 
            Any of the multimodal (vision) models supported by Ollama,
            including specific model versions.
        caption_prompt (str | None, optional): 
            System prompt for captioning images.
            Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
        validation (bool, optional): 
            Use validation step or not. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image captions; it should specify the range of the validation score to be generated.
            Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image caption. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Notes:
        `validation_prompt` should specify the rules for validating the caption and the range of validation score to be generated example (0-1).
        Your `validation_threshold` should be within this specified range.
    """

    if not self._validate_ollama_model(caption_model):
        raise ValueError(f"Caption model {caption_model} is not supported by Ollama.")

    if not self._validate_ollama_model(validation_model):
        raise ValueError(f"Validation model {validation_model} is not supported by Ollama.")

    self.caption_model = caption_model
    self.validation_model = validation_model

    super().__init__(
        caption_prompt=caption_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a caption. Implements the logic to generate captions for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the caption does not pass validation threshold.

Parameters:

    image (str, required):
        Base64 encoded image.
    feedback_prompt (str, default ''):
        Validation feedback from the previous iteration, used to prompt the model for a better caption.
    **kwargs (default {}):
        Additional arguments used to control the model's generation parameters.

Returns:

    str: Generated caption for the image.
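
Because image is a base64 string rather than a file path, it has to be encoded first when calling annotate directly. A minimal sketch, reusing the captioner instance from the example above; extra keyword arguments are passed to Ollama as generation options:

```python
import base64

# Encode an image file into the base64 string that annotate()/validate() expect.
with open("path/to/image1.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# temperature defaults to 0.0 inside annotate(); other Ollama options can be passed the same way.
caption = captioner.annotate(image_b64, temperature=0.2)
```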

Source code in swiftannotate/image/captioning/ollama.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
    """
    Annotates the image with a caption. Implements the logic to generate captions for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the caption does not pass validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated caption for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the caption you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Try to generate a better caption for the image.
        """
    else:
        user_prompt = "Describe the given image."

    messages=[
        {"role": "system", "content": self.caption_prompt},
        {
            "role": "user",
            "images": [image],
            "content": user_prompt
        }
    ]

    if not "temperature" in kwargs:
        kwargs["temperature"] = 0.0

    try:

        response = ollama.chat(
            model=self.caption_model,
            messages=messages,
            options=kwargs
        )
        image_caption = response.message.content

    except Exception as e:
        logger.error(f"Image captioning failed: {e}")
        image_caption = "ERROR"

    return image_caption

generate(image_paths, **kwargs)

Generates captions for a list of images. Implements the logic to generate captions for a list of images.

Parameters:

    image_paths (List[str], required):
        List of image paths to generate captions for.
    **kwargs (default {}):
        Additional arguments used to control the model's generation parameters.

Returns:

    List[Dict]: List of captions, validation reasoning, and confidence scores for each image.
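
The dictionaries returned by generate use the same keys as the example outputs shown for the other pipelines on this page, so post-processing is straightforward. A small sketch that keeps only captions scored at or above the default threshold:

```python
results = captioner.generate(["path/to/image1.jpg", "path/to/image2.jpg"])

# Keep captions whose validation score is at or above the default threshold of 0.5.
accepted = [r for r in results if r["validation_score"] >= 0.5]
for r in accepted:
    print(r["image_path"], "->", r["image_caption"])
```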

Source code in swiftannotate/image/captioning/ollama.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates captions for a list of images. Implements the logic to generate captions for a list of images.

    Args:
        image_paths (List[str]): List of image paths to generate captions for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of captions, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results

validate(image, caption, **kwargs)

Validates the caption generated for the image.

Parameters:

    image (str, required):
        Base64 encoded image.
    caption (str, required):
        Caption generated for the image.

Returns:

    Tuple[str, float]: Validation reasoning and confidence score for the caption.
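
In the source below, validate constrains the Ollama response with ImageValidationOutputOllama.model_json_schema() and parses it back with model_validate_json. That model is defined elsewhere in swiftannotate and is not shown on this page; judging only by the fields accessed here, it presumably looks roughly like this Pydantic sketch:

```python
from pydantic import BaseModel

# Assumed shape, inferred from the fields read in validate(); the real
# definition in swiftannotate may differ.
class ImageValidationOutputOllama(BaseModel):
    validation_reasoning: str
    confidence: float
```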

Source code in swiftannotate/image/captioning/ollama.py
def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]: 
    """
    Validates the caption generated for the image.

    Args:
        image (str): Base64 encoded image.
        caption (str): Caption generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the caption.
    """
    if caption == "ERROR":
        return "ERROR", 0

    messages = [
        {
            "role": "system",
            "content": self.validation_prompt
        },
        {
            "role": "user",
            "images": [image],
            "content": caption + "\nValidate the caption generated for the given image."
        }
    ]      

    if not "temperature" in kwargs:
        kwargs["temperature"] = 0.0

    try:

        response = ollama.chat(
            model=self.validation_model,
            messages=messages,
            format=ImageValidationOutputOllama.model_json_schema(),
            options=kwargs
        )

        validation_output = ImageValidationOutputOllama.model_validate_json(response.message.content)

        validation_reasoning = validation_output.validation_reasoning
        confidence = validation_output.confidence

    except Exception as e:
        logger.error(f"Image caption validation failed: {e}")
        validation_reasoning = "ERROR"
        confidence = 0

    return validation_reasoning, confidence

OpenAIForImageCaptioning

Bases: BaseImageCaptioning

OpenAIForImageCaptioning pipeline using OpenAI API.

Example usage:

from swiftannotate.image import OpenAIForImageCaptioning

# Initialize the pipeline
captioner = OpenAIForImageCaptioning(
    caption_model="gpt-4o",
    validation_model="gpt-4o-mini",
    api_key="your_api_key_here",
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = captioner.generate(image_paths)

# Print results
# Output: [
#     {
#         'image_path': 'path/to/image1.jpg', 
#         'image_caption': 'A cat sitting on a table.', 
#         'validation_reasoning': 'The caption is valid.', 
#         'validation_score': 0.8
#     }, 
# ]
Source code in swiftannotate/image/captioning/openai.py
class OpenAIForImageCaptioning(BaseImageCaptioning):
    """
    OpenAIForImageCaptioning pipeline using OpenAI API.

    Example usage:

    ```python
    from swiftannotate.image import OpenAIForImageCaptioning

    # Initialize the pipeline
    captioner = OpenAIForImageCaptioning(
        caption_model="gpt-4o",
        validation_model="gpt-4o-mini",
        api_key="your_api_key_here",
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = captioner.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         'image_path': 'path/to/image1.jpg', 
    #         'image_caption': 'A cat sitting on a table.', 
    #         'validation_reasoning': 'The caption is valid.', 
    #         'validation_score': 0.8
    #     }, 
    # ]
    ```
    """

    def __init__(
        self, 
        caption_model: str, 
        validation_model: str,
        api_key: str, 
        caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
        **kwargs
    ):
        """
        Initializes the ImageCaptioningOpenAI pipeline.

        Args:
            caption_model (str): 
                Can be either "gpt-4o", "gpt-4o-mini", etc. or 
                specific versions of model supported by OpenAI.
            validation_model (str): 
                Can be either "gpt-4o", "gpt-4o-mini", etc. or 
                specific versions of model supported by OpenAI.
            api_key (str): OpenAI API key.
            caption_prompt (str | None, optional): 
                System prompt for captioning images.
                Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
            validation (bool, optional): 
                Use validation step or not. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image captions; it should specify the range of the validation score to be generated.
                Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image caption. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Keyword Arguments:
            detail (str, optional): 
                Specific to OpenAI. Detail level of the image (Higher resolution costs more). Defaults to "low".

        Notes:
            `validation_prompt` should specify the rules for validating the caption and the range of validation score to be generated example (0-1).
            Your `validation_threshold` should be within this specified range.
        """
        self.caption_model = caption_model
        self.validation_model = validation_model
        self.client = OpenAI(api_key=api_key)

        super().__init__(
            caption_prompt=caption_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

        self.detail = kwargs.get("detail", "low")

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
        """
        Annotates the image with a caption. Implements the logic to generate captions for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the caption does not pass validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated caption for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the caption you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Try to generate a better caption for the image.
            """
        else:
            user_prompt = "Describe the given image."

        messages=[
            {"role": "system", "content": self.caption_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}",
                            "detail": self.detail
                        },
                    },
                    {"type": "text", "text": user_prompt},
                ]
            }
        ]

        try:
            response = self.client.chat.completions.create(
                model=self.caption_model,
                messages=messages,
                **kwargs
            )
            image_caption = response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"Image captioning failed: {e}")
            image_caption = "ERROR"

        return image_caption

    def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]: 
        """
        Validates the caption generated for the image.

        Args:
            image (str): Base64 encoded image.
            caption (str): Caption generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the caption.
        """
        if caption == "ERROR":
            return "ERROR", 0

        messages = [
            {
                "role": "system",
                "content": self.validation_prompt
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}",
                            "detail": self.detail
                        },
                    },
                    {
                        "type": "text",
                        "text": caption + "\nValidate the caption generated for the given image."
                    }
                ]
            }
        ]      

        try:
            response = self.client.chat.completions.create(
                model=self.validation_model,
                messages=messages,
                response_format=ImageValidationOutputOpenAI,
                **kwargs
            )
            validation_output = response.choices[0].message.parsed
            validation_reasoning = validation_output.validation_reasoning
            confidence = validation_output.confidence

        except Exception as e:
            logger.error(f"Image caption validation failed: {e}")
            validation_reasoning = "ERROR"
            confidence = 0

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates captions for a list of images. Implements the logic to generate captions for a list of images.

        Args:
            image_paths (List[str]): List of image paths to generate captions for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of captions, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(caption_model, validation_model, api_key, caption_prompt=BASE_IMAGE_CAPTION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CAPTION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None, **kwargs)

Initializes the ImageCaptioningOpenAI pipeline.

Parameters:

    caption_model (str, required):
        Can be "gpt-4o", "gpt-4o-mini", etc., or a specific model version supported by OpenAI.
    validation_model (str, required):
        Can be "gpt-4o", "gpt-4o-mini", etc., or a specific model version supported by OpenAI.
    api_key (str, required):
        OpenAI API key.
    caption_prompt (str | None, default BASE_IMAGE_CAPTION_PROMPT):
        System prompt for captioning images. Uses the default BASE_IMAGE_CAPTION_PROMPT if not provided.
    validation (bool, default True):
        Whether to run the validation step.
    validation_prompt (str | None, default BASE_IMAGE_CAPTION_VALIDATION_PROMPT):
        System prompt for validating image captions; it should specify the range of the validation score to be generated. Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
    validation_threshold (float, default 0.5):
        Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
    max_retry (int, default 3):
        Number of retries before giving up on the image caption.
    output_file (str | None, default None):
        Output file path; only JSON is supported for now.

Other Parameters:

    detail (str):
        Specific to OpenAI. Detail level of the image (higher resolution costs more). Defaults to "low".

Notes

validation_prompt should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1). Your validation_threshold should fall within this specified range.
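
For higher-resolution image input, the detail keyword can be raised at construction time. A short sketch mirroring the usage example at the top of this class:

```python
from swiftannotate.image import OpenAIForImageCaptioning

captioner = OpenAIForImageCaptioning(
    caption_model="gpt-4o",
    validation_model="gpt-4o-mini",
    api_key="your_api_key_here",
    output_file="captions.json",
    detail="high",   # optional keyword argument; defaults to "low"
)
```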

Source code in swiftannotate/image/captioning/openai.py
def __init__(
    self, 
    caption_model: str, 
    validation_model: str,
    api_key: str, 
    caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
    **kwargs
):
    """
    Initializes the ImageCaptioningOpenAI pipeline.

    Args:
        caption_model (str): 
            Can be either "gpt-4o", "gpt-4o-mini", etc. or 
            specific versions of model supported by OpenAI.
        validation_model (str): 
            Can be either "gpt-4o", "gpt-4o-mini", etc. or 
            specific versions of model supported by OpenAI.
        api_key (str): OpenAI API key.
        caption_prompt (str | None, optional): 
            System prompt for captioning images.
            Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
        validation (bool, optional): 
            Use validation step or not. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image captions; it should specify the range of the validation score to be generated.
            Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image caption. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Keyword Arguments:
        detail (str, optional): 
            Specific to OpenAI. Detail level of the image (Higher resolution costs more). Defaults to "low".

    Notes:
        `validation_prompt` should specify the rules for validating the caption and the range of validation score to be generated example (0-1).
        Your `validation_threshold` should be within this specified range.
    """
    self.caption_model = caption_model
    self.validation_model = validation_model
    self.client = OpenAI(api_key=api_key)

    super().__init__(
        caption_prompt=caption_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

    self.detail = kwargs.get("detail", "low")

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a caption. Implements the logic to generate captions for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the caption does not pass validation threshold.

Parameters:

    image (str, required):
        Base64 encoded image.
    feedback_prompt (str, default ''):
        Validation feedback from the previous iteration, used to prompt the model for a better caption.
    **kwargs (default {}):
        Additional arguments used to control the model's generation parameters.

Returns:

    str: Generated caption for the image.
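
Keyword arguments are forwarded unchanged to client.chat.completions.create, so standard Chat Completions parameters apply. A one-line sketch, assuming a captioner instance and an image_b64 string prepared as in the Ollama example earlier:

```python
# temperature and max_tokens are ordinary Chat Completions parameters.
caption = captioner.annotate(image_b64, temperature=0.2, max_tokens=200)
```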

Source code in swiftannotate/image/captioning/openai.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
    """
    Annotates the image with a caption. Implements the logic to generate captions for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the caption does not pass validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated caption for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the caption you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Try to generate a better caption for the image.
        """
    else:
        user_prompt = "Describe the given image."

    messages=[
        {"role": "system", "content": self.caption_prompt},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image}",
                        "detail": self.detail
                    },
                },
                {"type": "text", "text": user_prompt},
            ]
        }
    ]

    try:
        response = self.client.chat.completions.create(
            model=self.caption_model,
            messages=messages,
            **kwargs
        )
        image_caption = response.choices[0].message.content.strip()

    except Exception as e:
        logger.error(f"Image captioning failed: {e}")
        image_caption = "ERROR"

    return image_caption

generate(image_paths, **kwargs)

Generates captions for a list of images. Implements the logic to generate captions for a list of images.

Parameters:

    image_paths (List[str], required):
        List of image paths to generate captions for.
    **kwargs (default {}):
        Additional arguments used to control the model's generation parameters.

Returns:

    List[Dict]: List of captions, validation reasoning, and confidence scores for each image.

Source code in swiftannotate/image/captioning/openai.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates captions for a list of images. Implements the logic to generate captions for a list of images.

    Args:
        image_paths (List[str]): List of image paths to generate captions for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of captions, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results

validate(image, caption, **kwargs)

Validates the caption generated for the image.

Parameters:

    image (str, required):
        Base64 encoded image.
    caption (str, required):
        Caption generated for the image.

Returns:

    Tuple[str, float]: Validation reasoning and confidence score for the caption.

Source code in swiftannotate/image/captioning/openai.py
def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]: 
    """
    Validates the caption generated for the image.

    Args:
        image (str): Base64 encoded image.
        caption (str): Caption generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the caption.
    """
    if caption == "ERROR":
        return "ERROR", 0

    messages = [
        {
            "role": "system",
            "content": self.validation_prompt
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image}",
                        "detail": self.detail
                    },
                },
                {
                    "type": "text",
                    "text": caption + "\nValidate the caption generated for the given image."
                }
            ]
        }
    ]      

    try:
        response = self.client.chat.completions.create(
            model=self.validation_model,
            messages=messages,
            response_format=ImageValidationOutputOpenAI,
            **kwargs
        )
        validation_output = response.choices[0].message.parsed
        validation_reasoning = validation_output.validation_reasoning
        confidence = validation_output.confidence

    except Exception as e:
        logger.error(f"Image caption validation failed: {e}")
        validation_reasoning = "ERROR"
        confidence = 0

    return validation_reasoning, confidence

Qwen2VLForImageCaptioning

Bases: BaseImageCaptioning

Qwen2VLForImageCaptioning pipeline using Qwen2VL model.

Example usage:

from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import BitsAndBytesConfig
from swiftannotate.image import Qwen2VLForImageCaptioning

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True
)

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    device_map="auto",
    torch_dtype="auto",
    quantization_config=quantization_config)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Load the Caption Model
captioning_pipeline = Qwen2VLForImageCaptioning(
    model = model,
    processor = processor,
    output_file="captions.json"
)

# Generate captions for images
image_paths = ['path/to/image1.jpg']
results = captioning_pipeline.generate(image_paths)

# Print results
# Output: [
#     {
#         'image_path': 'path/to/image1.jpg', 
#         'image_caption': 'A cat sitting on a table.', 
#         'validation_reasoning': 'The caption is valid.', 
#         'validation_score': 0.8
#     }, 
# ]

Source code in swiftannotate/image/captioning/qwen.py
class Qwen2VLForImageCaptioning(BaseImageCaptioning):
    """
    Qwen2VLForImageCaptioning pipeline using Qwen2VL model.

    Example usage:
    ```python
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from transformers import BitsAndBytesConfig
    from swiftannotate.image import Qwen2VLForImageCaptioning

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="float16",
        bnb_4bit_use_double_quant=True
    )

    model = AutoModelForImageTextToText.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        device_map="auto",
        torch_dtype="auto",
        quantization_config=quantization_config)

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    # Load the Caption Model
    captioning_pipeline = Qwen2VLForImageCaptioning(
        model = model,
        processor = processor,
        output_file="captions.json"
    )

    # Generate captions for images
    image_paths = ['path/to/image1.jpg']
    results = captioning_pipeline.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         'image_path': 'path/to/image1.jpg', 
    #         'image_caption': 'A cat sitting on a table.', 
    #         'validation_reasoning': 'The caption is valid.', 
    #         'validation_score': 0.8
    #     }, 
    # ]
    ```
    """
    def __init__(
        self, 
        model: AutoModelForImageTextToText | Qwen2VLForConditionalGeneration, 
        processor: AutoProcessor | Qwen2VLProcessor,
        caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
        **kwargs
    ):     
        """
        Initializes the ImageCaptioningQwen2VL pipeline.

        Args:
            model (AutoModelForImageTextToText): 
                Model for image captioning. Should be an instance of AutoModelForImageTextToText with Qwen2-VL pretrained weights.
                Can be any version of Qwen2-VL model (7B, 72B).
            processor (AutoProcessor): 
                Processor for the Qwen2-VL model. Should be an instance of AutoProcessor.
            caption_prompt (str | None, optional): 
                System prompt for captioning images.
                Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
            validation (bool, optional): 
                Use validation step or not. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image captions; it should specify the range of the validation score to be generated.
                Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image caption. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Keyword Arguments:
            resize_height (int, optional):
                Height to resize the image before generating captions. Defaults to 280.
            resize_width (int, optional):
                Width to resize the image before generating captions. Defaults to 420.

        Notes:
            `validation_prompt` should specify the rules for validating the caption and the range of validation score to be generated example (0-1).
            Your `validation_threshold` should be within this specified range.
        """    

        if not isinstance(model, Qwen2VLForConditionalGeneration):
            raise ValueError("Model should be an instance of Qwen2VLForConditionalGeneration.")
        if not isinstance(processor, Qwen2VLProcessor):
            raise ValueError("Processor should be an instance of Qwen2VLProcessor.")

        self.model = model
        self.processor = processor

        super().__init__(
            caption_prompt=caption_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

        self.resize_height = kwargs.get("resize_height", 280)
        self.resize_width = kwargs.get("resize_width", 420)

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
        """
        Annotates the image with a caption. Implements the logic to generate captions for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the caption does not pass validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated caption for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the caption you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Try to generate a better caption for the image.
            """
        else:
            user_prompt = "Describe the given image."

        messages = [
            {"role": "system", "content": self.caption_prompt},
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image", 
                        "image": f"data:image;base64,{image}",
                        "resized_height": self.resize_height,
                        "resized_width": self.resize_width,
                    },
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Inference: Generation of the output
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 512

        generated_ids = self.model.generate(**inputs, **kwargs)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        image_caption = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return image_caption

    def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]:
        """
        Validates the caption generated for the image.

        Args:
            image (str): Base64 encoded image.
            caption (str): Caption generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the caption.
        """
        messages = [
            {"role": "system", "content": self.validation_prompt},
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image", 
                        "image": f"data:image;base64,{image}",
                        "resized_height": self.resize_height,
                        "resized_width": self.resize_width,
                    },
                    {"type": "text", "text": caption},
                    {
                        "type": "text", 
                        "text": """
                        Validate the caption generated for the given image. 
                        Return output as a JSON object with keys as 'validation_reasoning' and 'confidence'.
                        """
                    },
                ],
            },
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Inference: Generation of the output
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 512

        generated_ids = self.model.generate(**inputs, **kwargs)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        validation_output = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # TODO: Need a better way to parse the output
        try:
            validation_output = validation_output.replace('```', '').replace('json', '')
            validation_output = json.loads(validation_output)
            validation_reasoning = validation_output["validation_reasoning"]
            confidence = validation_output["confidence"]
        except Exception as e:
            logger.error("Image caption validation parsing failed, falling back to numeric extraction.")

            # Keep only numeric-looking characters, then collect tokens that parse as floats in [0, 1].
            number_str = ''.join((ch if ch in '0123456789.-e' else ' ') for ch in validation_output)
            potential_confidence_scores = []
            for token in number_str.split():
                try:
                    value = float(token)
                except ValueError:
                    continue
                if 0 <= value <= 1:
                    potential_confidence_scores.append(value)
            confidence = max(potential_confidence_scores) if potential_confidence_scores else 0.0
            validation_reasoning = validation_output

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates captions for a list of images. Implements the logic to generate captions for a list of images.

        Args:
            image_paths (List[str]): List of image paths to generate captions for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of captions, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(model, processor, caption_prompt=BASE_IMAGE_CAPTION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CAPTION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None, **kwargs)

Initializes the ImageCaptioningQwen2VL pipeline.

Parameters:

    model (AutoModelForImageTextToText, required):
        Model for image captioning. Should be an instance of AutoModelForImageTextToText with Qwen2-VL pretrained weights. Can be any version of the Qwen2-VL model (7B, 72B).
    processor (AutoProcessor, required):
        Processor for the Qwen2-VL model. Should be an instance of AutoProcessor.
    caption_prompt (str | None, default BASE_IMAGE_CAPTION_PROMPT):
        System prompt for captioning images. Uses the default BASE_IMAGE_CAPTION_PROMPT if not provided.
    validation (bool, default True):
        Whether to run the validation step.
    validation_prompt (str | None, default BASE_IMAGE_CAPTION_VALIDATION_PROMPT):
        System prompt for validating image captions; it should specify the range of the validation score to be generated. Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
    validation_threshold (float, default 0.5):
        Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
    max_retry (int, default 3):
        Number of retries before giving up on the image caption.
    output_file (str | None, default None):
        Output file path; only JSON is supported for now.

Other Parameters:

    resize_height (int):
        Height to resize the image to before generating captions. Defaults to 280.
    resize_width (int):
        Width to resize the image to before generating captions. Defaults to 420.

Notes

validation_prompt should specify the rules for validating the caption and the range of the validation score to be generated, for example (0-1). Your validation_threshold should fall within this specified range.
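
The resize keyword arguments control how images are scaled before reaching the model. A short sketch reusing the model and processor from the usage example above; the sizes are illustrative and, like the defaults, are multiples of 28:

```python
captioning_pipeline = Qwen2VLForImageCaptioning(
    model=model,                  # Qwen2VLForConditionalGeneration from the example above
    processor=processor,          # Qwen2VLProcessor from the example above
    output_file="captions.json",
    resize_height=560,            # optional; defaults to 280
    resize_width=840,             # optional; defaults to 420
)
```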

Source code in swiftannotate/image/captioning/qwen.py
def __init__(
    self, 
    model: AutoModelForImageTextToText | Qwen2VLForConditionalGeneration, 
    processor: AutoProcessor | Qwen2VLProcessor,
    caption_prompt: str = BASE_IMAGE_CAPTION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CAPTION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
    **kwargs
):     
    """
    Initializes the ImageCaptioningQwen2VL pipeline.

    Args:
        model (AutoModelForImageTextToText): 
            Model for image captioning. Should be an instance of AutoModelForImageTextToText with Qwen2-VL pretrained weights.
            Can be any version of Qwen2-VL model (7B, 72B).
        processor (AutoProcessor): 
            Processor for the Qwen2-VL model. Should be an instance of AutoProcessor.
        caption_prompt (str | None, optional): 
            System prompt for captioning images.
            Uses default BASE_IMAGE_CAPTION_PROMPT prompt if not provided.
        validation (bool, optional): 
            Use validation step or not. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image captions; it should specify the range of the validation score to be generated.
            Uses the default BASE_IMAGE_CAPTION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold that determines whether a caption is valid; it must lie within the score range specified in the validation prompt.
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image caption. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Keyword Arguments:
        resize_height (int, optional):
            Height to resize the image before generating captions. Defaults to 280.
        resize_width (int, optional):
            Width to resize the image before generating captions. Defaults to 420.

    Notes:
        `validation_prompt` should specify the rules for validating the caption and the range of validation score to be generated example (0-1).
        Your `validation_threshold` should be within this specified range.
    """    

    if not isinstance(model, Qwen2VLForConditionalGeneration):
        raise ValueError("Model should be an instance of Qwen2VLForConditionalGeneration.")
    if not isinstance(processor, Qwen2VLProcessor):
        raise ValueError("Processor should be an instance of Qwen2VLProcessor.")

    self.model = model
    self.processor = processor

    super().__init__(
        caption_prompt=caption_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

    self.resize_height = kwargs.get("resize_height", 280)
    self.resize_width = kwargs.get("resize_width", 420)

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a caption. Implements the logic to generate captions for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the caption does not pass validation threshold.

Parameters:

    image (str, required):
        Base64 encoded image.
    feedback_prompt (str, default ''):
        Validation feedback from the previous iteration, used to prompt the model for a better caption.
    **kwargs (default {}):
        Additional arguments used to control the model's generation parameters.

Returns:

    str: Generated caption for the image.

Source code in swiftannotate/image/captioning/qwen.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
    """
    Annotates the image with a caption. Implements the logic to generate captions for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the caption does not pass validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better caption. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated caption for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the caption you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Try to generate a better caption for the image.
        """
    else:
        user_prompt = "Describe the given image."

    messages = [
        {"role": "system", "content": self.caption_prompt},
        {
            "role": "user", 
            "content": [
                {
                    "type": "image", 
                    "image": f"data:image;base64,{image}",
                    "resized_height": self.resize_height,
                    "resized_width": self.resize_width,
                },
                {"type": "text", "text": user_prompt},
            ],
        },
    ]

    text = self.processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = self.processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(self.model.device)

    # Inference: Generation of the output
    if "max_new_tokens" not in kwargs:
        kwargs["max_new_tokens"] = 512

    generated_ids = self.model.generate(**inputs, **kwargs)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    image_caption = self.processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return image_caption

generate(image_paths, **kwargs)

Generates captions for a list of images. Implements the logic to generate captions for a list of images.

Parameters:

    image_paths (List[str], required):
        List of image paths to generate captions for.
    **kwargs (default {}):
        Additional arguments used to control the model's generation parameters.

Returns:

    List[Dict]: List of captions, validation reasoning, and confidence scores for each image.
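
Extra keyword arguments end up in self.model.generate (max_new_tokens defaults to 512 when not given), so Hugging Face generation settings can be tuned per call. A sketch, assuming the base pipeline forwards kwargs to annotate and validate as the signatures suggest:

```python
results = captioning_pipeline.generate(
    ["path/to/image1.jpg"],
    max_new_tokens=256,   # forwarded to model.generate
    do_sample=False,
)
```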

Source code in swiftannotate/image/captioning/qwen.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates captions for a list of images. Implements the logic to generate captions for a list of images.

    Args:
        image_paths (List[str]): List of image paths to generate captions for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of captions, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results

validate(image, caption, **kwargs)

Validates the caption generated for the image.

Parameters:

    image (str, required):
        Base64 encoded image.
    caption (str, required):
        Caption generated for the image.

Returns:

    Tuple[str, float]: Validation reasoning and confidence score for the caption.

Source code in swiftannotate/image/captioning/qwen.py
def validate(self, image: str, caption: str, **kwargs) -> Tuple[str, float]:
    """
    Validates the caption generated for the image.

    Args:
        image (str): Base64 encoded image.
        caption (str): Caption generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the caption.
    """
    messages = [
        {"role": "system", "content": self.validation_prompt},
        {
            "role": "user", 
            "content": [
                {
                    "type": "image", 
                    "image": f"data:image;base64,{image}",
                    "resized_height": self.resize_height,
                    "resized_width": self.resize_width,
                },
                {"type": "text", "text": caption},
                {
                    "type": "text", 
                    "text": """
                    Validate the caption generated for the given image. 
                    Return output as a JSON object with keys as 'validation_reasoning' and 'confidence'.
                    """
                },
            ],
        },
    ]

    text = self.processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = self.processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(self.model.device)

    # Inference: Generation of the output
    if "max_new_tokens" not in kwargs:
        kwargs["max_new_tokens"] = 512

    generated_ids = self.model.generate(**inputs, **kwargs)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    validation_output = self.processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # TODO: Need a better way to parse the output
    try:
        validation_output = validation_output.replace('```', '').replace('json', '')
        validation_output = json.loads(validation_output)
        validation_reasoning = validation_output["validation_reasoning"]
        confidence = validation_output["confidence"]
    except Exception as e:
        logger.error(f"Image caption validation parsing failed trying to parse using another logic.")

        number_str  = ''.join((ch if ch in '0123456789.-e' else ' ') for ch in validation_output)
        number_str = [i for i in number_str.split() if i.isalnum()]
        potential_confidence_scores = [float(i) for i in number_str if float(i) >= 0 and float(i) <= 1]
        confidence = max(potential_confidence_scores) if potential_confidence_scores else 0.0
        validation_reasoning = validation_output

    return validation_reasoning, confidence
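The TODO above flags that this JSON parsing is brittle. One possible direction, shown here only as a sketch (the helper `parse_validation_output` is illustrative and not part of swiftannotate), is to pull the first `{...}` block out of the response before resorting to the numeric heuristic:

```python
import json
import re

def parse_validation_output(raw: str) -> tuple[str, float]:
    """Illustrative helper: extract the validation JSON from a model response.

    Looks for the first {...} block (tolerating code fences and surrounding prose);
    if parsing fails, returns the raw text with a confidence of 0.0.
    """
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
    if match:
        try:
            data = json.loads(match.group(0))
            return str(data.get("validation_reasoning", "")), float(data.get("confidence", 0.0))
        except (json.JSONDecodeError, TypeError, ValueError):
            pass
    return raw, 0.0
```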

Classification

GeminiForImageClassification

Bases: BaseImageClassification

GeminiForImageClassification pipeline for classifying images using Gemini models.

Example usage:

from swiftannotate.image import GeminiForImageClassification

# Initialize the pipeline
classification_pipeline = GeminiForImageClassification(
    caption_model="gemini-1.5-pro",
    validation_model="gemini-1.5-flash",
    api_key="your_api_key_here",
    classification_labels=["kitchen", "bedroom", "living room"],
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = classification_pipeline.generate(image_paths)

# Print results
# Output: [
#     {
#         "image_path": 'path/to/image1.jpg', 
#         "image_classification": 'kitchen', 
#         "validation_reasoning": 'The class label is valid.', 
#         "validation_score": 0.6
#     }, 
# ]

Source code in swiftannotate/image/classification/gemini.py
class GeminiForImageClassification(BaseImageClassification):
    """
    GeminiForImageClassification pipeline for classifying images using Gemini models.

    Example usage:
    ```python
    from swiftannotate.image import GeminiForImageClassification

    # Initialize the pipeline
    classification_pipeline = GeminiForImageClassification(
        classification_model="gemini-1.5-pro",
        validation_model="gemini-1.5-flash",
        api_key="your_api_key_here",
        classification_labels=["kitchen", "bedroom", "living room"],
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = classification_pipeline.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         "image_path": 'path/to/image1.jpg', 
    #         "image_classification": 'kitchen', 
    #         "validation_reasoning": 'The class label is valid.', 
    #         "validation_score": 0.6
    #     }, 
    # ]
    ```
    """

    def __init__(
        self, 
        classification_model: str, 
        validation_model: str,
        api_key: str, 
        classification_labels: List[str],
        classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
    ):
        """
        Initializes the GeminiForImageClassification pipeline.

        Args:
            classification_model (str): 
                Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
            validation_model (str): 
                Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
            api_key (str): 
                Google Gemini API key.
            classification_labels (List[str]):
                List of classification labels to be used for the image classification.
            classification_prompt (str | None, optional): 
                System prompt for classifying images.
                Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
            validation (bool, optional): 
                Whether to run the validation step. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image class labels; it should specify the range of the validation score to be generated. 
                Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt. 
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image class labels. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Notes:
            `validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
            Your `validation_threshold` should be within this specified range.

            It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
            You can also add Few-shot learning examples to the prompt to help the model understand the context of the class labels.
        """        
        genai.configure(api_key=api_key)
        self.classification_model = genai.GenerativeModel(model_name=classification_model)
        self.validation_model = genai.GenerativeModel(model_name=validation_model)

        super().__init__(
            classification_labels=classification_labels,
            classification_prompt=classification_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
        """
        Annotates the image with a class label. Implements the logic to generate class labels for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the class label does not pass the validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated class label for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the class label you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Regenerate the class label for the given image.
                Classify the given image as {', '.join(map(str, self.classification_labels))}
            """
        else:
            user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))}"

        messages = [
            self.classification_prompt,
            {'mime_type':'image/jpeg', 'data': image}, 
            user_prompt
        ]

        try:
            output = self.classification_model.generate_content(
                messages,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json", 
                    response_schema=ImageClassificationOutputGemini,
                    **kwargs
                )
            )
            class_label = output["class_label"].lower()
        except Exception as e:
            logger.error(f"Image classification failed: {e}")
            class_label = "ERROR"

        return class_label

    def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]:
        """
        Validates the class label generated for the image.

        Args:
            image (str): Base64 encoded image.
            class_label (str): Class Label generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the class label.
        """
        if class_label == "ERROR":
            return "ERROR", 0.0

        messages = [
            self.validation_prompt,
            {'mime_type':'image/jpeg', 'data': image},
            class_label,
            "Validate the class label generated for the given image."
        ]

        try:
            validation_output = self.validation_model.generate_content(
                messages,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json", 
                    response_schema=ImageValidationOutputGemini
                )
            )
            validation_reasoning = validation_output["validation_reasoning"]
            confidence = validation_output["confidence"]
        except Exception as e:
            logger.error(f"Image class label validation failed: {e}")
            validation_reasoning = "ERROR"
            confidence = 0.0

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates class label for a list of images. 

        Args:
            image_paths (List[str]): List of image paths to generate class labels for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
        """

        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results 

__init__(classification_model, validation_model, api_key, classification_labels, classification_prompt=BASE_IMAGE_CLASSIFICATION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None)

Initializes the GeminiForImageClassification pipeline.

Parameters:

    classification_model (str, required):
        Can be "gemini-1.5-flash", "gemini-1.5-pro", etc., or a specific model version supported by Gemini.
    validation_model (str, required):
        Can be "gemini-1.5-flash", "gemini-1.5-pro", etc., or a specific model version supported by Gemini.
    api_key (str, required):
        Google Gemini API key.
    classification_labels (List[str], required):
        List of classification labels to be used for the image classification.
    classification_prompt (str | None, optional; default BASE_IMAGE_CLASSIFICATION_PROMPT):
        System prompt for classifying images.
    validation (bool, optional; default True):
        Whether to run the validation step.
    validation_prompt (str | None, optional; default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT):
        System prompt for validating image class labels; it should specify the range of the validation score to be generated.
    validation_threshold (float, optional; default 0.5):
        Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt.
    max_retry (int, optional; default 3):
        Number of retries before giving up on the image class label.
    output_file (str | None, optional; default None):
        Output file path; only JSON is supported for now.

Notes

validation_prompt should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1). Your validation_threshold should fall within this specified range.

It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels. You can also add few-shot examples to the prompts for the same purpose.
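As the notes recommend, the validation prompt can spell out the class definitions, the scoring rules, and the score range explicitly. A sketch with illustrative prompt wording (not the library's built-in BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT):

```python
# Sketch: a custom validation prompt that states the class definitions and the
# 0-1 scoring range, as recommended in the notes above. Prompt text is illustrative.
custom_validation_prompt = (
    "You are validating an image classification label. "
    "kitchen: room with cooking appliances; bedroom: room with a bed; "
    "living room: room with sofas and a TV. "
    "Return a confidence score between 0 and 1, where 1 means the label clearly matches the image."
)

classification_pipeline = GeminiForImageClassification(
    classification_model="gemini-1.5-pro",
    validation_model="gemini-1.5-flash",
    api_key="your_api_key_here",
    classification_labels=["kitchen", "bedroom", "living room"],
    validation_prompt=custom_validation_prompt,
    validation_threshold=0.7,  # must fall inside the 0-1 range described in the prompt
    output_file="classifications.json",
)
```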

Source code in swiftannotate/image/classification/gemini.py
def __init__(
    self, 
    classification_model: str, 
    validation_model: str,
    api_key: str, 
    classification_labels: List[str],
    classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
):
    """
    Initializes the GeminiForImageClassification pipeline.

    Args:
        classification_model (str): 
            Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
        validation_model (str): 
            Can be either "gemini-1.5-flash", "gemini-1.5-pro", etc. or specific versions of model supported by Gemini.
        api_key (str): 
            Google Gemini API key.
        classification_labels (List[str]):
            List of classification labels to be used for the image classification.
        classification_prompt (str | None, optional): 
            System prompt for classifying images.
            Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
        validation (bool, optional): 
            Whether to run the validation step. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image class labels; it should specify the range of the validation score to be generated. 
            Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt. 
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image class labels. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Notes:
        `validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
        Your `validation_threshold` should be within this specified range.

        It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
        You can also add Few-shot learning examples to the prompt to help the model understand the context of the class labels.
    """        
    genai.configure(api_key=api_key)
    self.classification_model = genai.GenerativeModel(model=classification_model)
    self.validation_model = genai.GenerativeModel(model=validation_model)

    super().__init__(
        classification_labels=classification_labels,
        classification_prompt=classification_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a class label. Implements the logic to generate class labels for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the class label does not pass the validation threshold.

Parameters:

    image (str, required):
        Base64 encoded image.
    feedback_prompt (str, optional; default ''):
        Feedback prompt used to steer the model toward a better class label.
    **kwargs (optional):
        Additional keyword arguments used to control generation parameters of the model.

Returns:

    str: Generated class label for the image.

Source code in swiftannotate/image/classification/gemini.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
    """
    Annotates the image with a class label. Implements the logic to generate class labels for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the class label does not pass the validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated class label for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the class label you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Regenerate the class label for the given image.
            Classify the given image as {', '.join(map(str, self.classification_labels))}
        """
    else:
        user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))}"

    messages = [
        self.classification_prompt,
        {'mime_type':'image/jpeg', 'data': image}, 
        user_prompt
    ]

    try:
        output = self.classification_model.generate_content(
            messages,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json", 
                response_schema=ImageClassificationOutputGemini,
                **kwargs
            )
        )
        class_label = output["class_label"].lower()
    except Exception as e:
        logger.error(f"Image classification failed: {e}")
        class_label = "ERROR"

    return class_label
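The note above describes how feedback_prompt is refilled with the previous validation reasoning whenever a label fails the threshold. That control flow lives in the base class; a simplified sketch of the interaction (not the actual BaseImageClassification implementation, and the attribute names are assumed from the constructor arguments) looks roughly like this:

```python
# Simplified sketch of the annotate/validate feedback loop described above.
# `pipeline` and `image` are assumed to exist; this is NOT the base class code.
feedback_prompt = ""
for attempt in range(pipeline.max_retry):
    class_label = pipeline.annotate(image, feedback_prompt=feedback_prompt)
    reasoning, score = pipeline.validate(image, class_label)
    if score >= pipeline.validation_threshold:
        break
    # Label failed validation: feed the reasoning back into the next attempt.
    feedback_prompt = reasoning
```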

generate(image_paths, **kwargs)

Generates class label for a list of images.

Parameters:

    image_paths (List[str], required):
        List of image paths to generate class labels for.
    **kwargs (optional):
        Additional keyword arguments used to control generation parameters of the model.

Returns:

    List[Dict]: List of class labels, validation reasoning, and confidence scores for each image.

Source code in swiftannotate/image/classification/gemini.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates class label for a list of images. 

    Args:
        image_paths (List[str]): List of image paths to generate class labels for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
    """

    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results 

validate(image, class_label, **kwargs)

Validates the class label generated for the image.

Parameters:

    image (str, required):
        Base64 encoded image.
    class_label (str, required):
        Class label generated for the image.

Returns:

    Tuple[str, float]: Validation reasoning and confidence score for the class label.

Source code in swiftannotate/image/classification/gemini.py
def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]:
    """
    Validates the class label generated for the image.

    Args:
        image (str): Base64 encoded image.
        class_label (str): Class Label generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the class label.
    """
    if class_label == "ERROR":
        return "ERROR", 0.0

    messages = [
        self.validation_prompt,
        {'mime_type':'image/jpeg', 'data': image},
        class_label,
        "Validate the class label generated for the given image."
    ]

    try:
        validation_output = self.validation_model.generate_content(
            messages,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json", 
                response_schema=ImageValidationOutputGemini
            )
        )
        validation_reasoning = validation_output["validation_reasoning"]
        confidence = validation_output["confidence"]
    except Exception as e:
        logger.error(f"Image class label validation failed: {e}")
        validation_reasoning = "ERROR"
        confidence = 0.0

    return validation_reasoning, confidence

OllamaForImageClassification

Bases: BaseImageClassification

OllamaForImageClassification pipeline for classifying images using Ollama models.

Example usage:

from swiftannotate.image import OllamaForImageClassification

# Initialize the pipeline
classification_pipeline = OllamaForImageClassification(
    classification_model="llama3.2-vision",
    validation_model="llama3.2-vision",
    classification_labels=["kitchen", "bedroom", "living room"],
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = classification_pipeline.generate(image_paths)

# Print results
# Output: [
#     {
#         "image_path": 'path/to/image1.jpg', 
#         "image_classification": 'kitchen', 
#         "validation_reasoning": 'The class label is valid.', 
#         "validation_score": 0.6
#     }, 
# ]
Source code in swiftannotate/image/classification/ollama.py
class OllamaForImageClassification(BaseImageClassification):
    """
    OllamaForImageClassification pipeline for classifying images using Ollama models.

    Example usage:

    ```python
    from swiftannotate.image import OllamaForImageClassification

    # Initialize the pipeline
    classification_pipeline = OllamaForImageClassification(
        classification_model="llama3.2-vision",
        validation_model="llama3.2-vision",
        classification_labels=["kitchen", "bedroom", "living room"],
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = classification_pipeline.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         "image_path": 'path/to/image1.jpg', 
    #         "image_classification": 'kitchen', 
    #         "validation_reasoning": 'The class label is valid.', 
    #         "validation_score": 0.6
    #     }, 
    # ]
    ```
    """
    def __init__(
        self, 
        classification_model: str, 
        validation_model: str,
        classification_labels: List[str],
        classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
    ):
        """
        Initializes the OllamaForImageClassification pipeline.

        Args:
            classification_model (str): 
                Any multimodal (vision) model supported by Ollama, including specific model versions.
            validation_model (str): 
                Any multimodal (vision) model supported by Ollama, including specific model versions.
            classification_labels (List[str]):
                List of classification labels to be used for the image classification.
            classification_prompt (str | None, optional): 
                System prompt for classifying images.
                Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
            validation (bool, optional): 
                Whether to run the validation step. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image class labels; it should specify the range of the validation score to be generated. 
                Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt. 
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image class labels. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Notes:
            `validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
            Your `validation_threshold` should be within this specified range.

            It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
            You can also add Few-shot learning examples to the prompt to help the model understand the context of the class labels.
        """

        if not self._validate_ollama_model(classification_model):
            raise ValueError(f"Model {classification_model} is not supported by Ollama.")

        if not self._validate_ollama_model(validation_model):
            raise ValueError(f"Model {validation_model} is not supported by Ollama.")

        self.classification_model = classification_model
        self.validation_model = validation_model

        super().__init__(
            classification_labels=classification_labels,
            classification_prompt=classification_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

    def _validate_ollama_model(self, model: str) -> bool:
        try:
            ollama.chat(model)
        except ollama.ResponseError as e:
            logger.error(f"Error: {e.error}")
            if e.status_code == 404:
                try:
                    ollama.pull(model)
                    logger.info(f"Model {model} is now downloaded.")
                except ollama.ResponseError as e:
                    logger.error(f"Error: {e.error}")
                    logger.error(f"Model {model} could not be downloaded. Check the model name and try again.")
                    return False
            logger.info(f"Model {model} is now downloaded.")

        return True

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
        """
        Annotates the image with a class label. Implements the logic to generate class labels for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the class label does not pass the validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated class label for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the class label you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Regenerate the class label for the given image.
                Classify the given image as {', '.join(map(str, self.classification_labels))}
            """
        else:
            user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))}"

        messages=[
            {"role": "system", "content": self.classification_prompt},
            {
                "role": "user",
                "images": [image],
                "content": user_prompt,
            }
        ]

        if not "temperature" in kwargs:
            kwargs["temperature"] = 0.0

        try:  
            response = ollama.chat(
                model=self.classification_model,
                messages=messages,
                format=ImageClassificationOutputOllama.model_json_schema(),
                options=kwargs
            )

            output = ImageClassificationOutputOllama.model_validate_json(response.message.content)
            class_label = output.class_label.lower()

        except Exception as e:
            logger.error(f"Image classification failed: {e}")
            class_label = "ERROR"

        return class_label

    def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]: 
        """
        Validates the class label generated for the image.

        Args:
            image (str): Base64 encoded image.
            class_label (str): Class Label generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the class label.
        """
        if class_label == "ERROR":
            return "ERROR", 0

        messages = [
            {
                "role": "system", "content": self.validation_prompt
            },
            {
                "role": "user",
                "images": [image],
                "content": class_label + "\nValidate the class label generated for the given image."
            }
        ] 

        if not "temperature" in kwargs:
            kwargs["temperature"] = 0.0     

        try:
            response = ollama.chat(
                model=self.validation_model,
                messages=messages,
                format=ImageValidationOutputOllama.model_json_schema(),
                options=kwargs
            )

            validation_output = ImageValidationOutputOllama.model_validate_json(response.message.content)

            validation_reasoning = validation_output.validation_reasoning
            confidence = validation_output.confidence

        except Exception as e:
            logger.error(f"Image class label validation failed: {e}")
            validation_reasoning = "ERROR"
            confidence = 0

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates class label for a list of images. 

        Args:
            image_paths (List[str]): List of image paths to generate class labels for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(classification_model, validation_model, classification_labels, classification_prompt=BASE_IMAGE_CLASSIFICATION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None)

Initializes the OllamaForImageClassification pipeline.

Parameters:

    classification_model (str, required):
        Any multimodal (vision) model supported by Ollama, including specific model versions.
    validation_model (str, required):
        Any multimodal (vision) model supported by Ollama, including specific model versions.
    classification_labels (List[str], required):
        List of classification labels to be used for the image classification.
    classification_prompt (str | None, optional; default BASE_IMAGE_CLASSIFICATION_PROMPT):
        System prompt for classifying images.
    validation (bool, optional; default True):
        Whether to run the validation step.
    validation_prompt (str | None, optional; default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT):
        System prompt for validating image class labels; it should specify the range of the validation score to be generated.
    validation_threshold (float, optional; default 0.5):
        Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt.
    max_retry (int, optional; default 3):
        Number of retries before giving up on the image class label.
    output_file (str | None, optional; default None):
        Output file path; only JSON is supported for now.

Notes

validation_prompt should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1). Your validation_threshold should fall within this specified range.

It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels. You can also add few-shot examples to the prompts for the same purpose.
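As the notes suggest, a classification prompt that spells out each class (optionally with a few-shot hint) usually helps smaller local models. A sketch with illustrative prompt wording:

```python
# Sketch: a classification prompt with class descriptions and a few-shot hint,
# as suggested in the notes above. Prompt wording is illustrative.
custom_classification_prompt = (
    "Classify indoor scenes into exactly one label.\n"
    "kitchen: cooking appliances, counters, sink.\n"
    "bedroom: bed, wardrobe, nightstand.\n"
    "living room: sofa, coffee table, TV.\n"
    "Example: an image showing a stove and a fridge -> kitchen."
)

classification_pipeline = OllamaForImageClassification(
    classification_model="llama3.2-vision",
    validation_model="llama3.2-vision",
    classification_labels=["kitchen", "bedroom", "living room"],
    classification_prompt=custom_classification_prompt,
    output_file="classifications.json",
)
```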

Source code in swiftannotate/image/classification/ollama.py
def __init__(
    self, 
    classification_model: str, 
    validation_model: str,
    classification_labels: List[str],
    classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
):
    """
    Initializes the OllamaForImageClassification pipeline.

    Args:
        classification_model (str): 
            Any multimodal (vision) model supported by Ollama, including specific model versions.
        validation_model (str): 
            Any multimodal (vision) model supported by Ollama, including specific model versions.
        classification_labels (List[str]):
            List of classification labels to be used for the image classification.
        classification_prompt (str | None, optional): 
            System prompt for classifying images.
            Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
        validation (bool, optional): 
            Whether to run the validation step. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image class labels; it should specify the range of the validation score to be generated. 
            Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt. 
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image class labels. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Notes:
        `validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
        Your `validation_threshold` should be within this specified range.

        It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
        You can also add Few-shot learning examples to the prompt to help the model understand the context of the class labels.
    """

    if not self._validate_ollama_model(classification_model):
        raise ValueError(f"Model {classification_model} is not supported by Ollama.")

    if not self._validate_ollama_model(validation_model):
        raise ValueError(f"Model {validation_model} is not supported by Ollama.")

    self.classification_model = classification_model
    self.validation_model = validation_model

    super().__init__(
        classification_labels=classification_labels,
        classification_prompt=classification_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a class label. Implements the logic to generate class labels for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration in case the class label does not pass the validation threshold.

Parameters:

    image (str, required):
        Base64 encoded image.
    feedback_prompt (str, optional; default ''):
        Feedback prompt used to steer the model toward a better class label.
    **kwargs (optional):
        Additional keyword arguments used to control generation parameters of the model.

Returns:

    str: Generated class label for the image.

Source code in swiftannotate/image/classification/ollama.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
    """
    Annotates the image with a class label. Implements the logic to generate class labels for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
    the previous iteration in case the class label does not pass the validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        str: Generated class label for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the class label you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Regenerate the class label for the given image.
            Classify the given image as {', '.join(map(str, self.classification_labels))}
        """
    else:
        user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))}"

    messages=[
        {"role": "system", "content": self.classification_prompt},
        {
            "role": "user",
            "images": [image],
            "content": user_prompt,
        }
    ]

    if not "temperature" in kwargs:
        kwargs["temperature"] = 0.0

    try:  
        response = ollama.chat(
            model=self.classification_model,
            messages=messages,
            format=ImageClassificationOutputOllama.model_json_schema(),
            options=kwargs
        )

        output = ImageClassificationOutputOllama.model_validate_json(response.message.content)
        class_label = output.class_label.lower()

    except Exception as e:
        logger.error(f"Image classification failed: {e}")
        class_label = "ERROR"

    return class_label

generate(image_paths, **kwargs)

Generates class label for a list of images.

Parameters:

    image_paths (List[str], required):
        List of image paths to generate class labels for.
    **kwargs (optional):
        Additional keyword arguments used to control generation parameters of the model.

Returns:

    List[Dict]: List of class labels, validation reasoning, and confidence scores for each image.

Source code in swiftannotate/image/classification/ollama.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates class label for a list of images. 

    Args:
        image_paths (List[str]): List of image paths to generate class labels for.
        **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

    Returns:
        List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results
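Since annotate and validate pass their keyword arguments to ollama.chat as the options dict (with temperature pinned to 0.0 unless overridden, as shown above), Ollama runtime options can be supplied through generate. A small sketch, assuming classification_pipeline was constructed as in the class example; num_predict is a standard Ollama option:

```python
# Sketch: passing Ollama runtime options through generate(); they end up in the
# `options` dict of ollama.chat, as shown in annotate()/validate() above.
results = classification_pipeline.generate(
    ["path/to/image1.jpg"],
    temperature=0.2,   # overrides the 0.0 default set by the pipeline
    num_predict=128,   # caps the number of generated tokens
)
print(results[0]["image_classification"], results[0]["validation_score"])
```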

validate(image, class_label, **kwargs)

Validates the class label generated for the image.

Parameters:

    image (str, required):
        Base64 encoded image.
    class_label (str, required):
        Class label generated for the image.

Returns:

    Tuple[str, float]: Validation reasoning and confidence score for the class label.

Source code in swiftannotate/image/classification/ollama.py
def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]: 
    """
    Validates the class label generated for the image.

    Args:
        image (str): Base64 encoded image.
        class_label (str): Class Label generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the class label.
    """
    if class_label == "ERROR":
        return "ERROR", 0

    messages = [
        {
            "role": "system", "content": self.validation_prompt
        },
        {
            "role": "user",
            "images": [image],
            "content": class_label + "\nValidate the class label generated for the given image."
        }
    ] 

    if not "temperature" in kwargs:
        kwargs["temperature"] = 0.0     

    try:
        response = ollama.chat(
            model=self.validation_model,
            messages=messages,
            format=ImageValidationOutputOllama.model_json_schema(),
            options=kwargs
        )

        validation_output = ImageValidationOutputOllama.model_validate_json(response.message.content)

        validation_reasoning = validation_output.validation_reasoning
        confidence = validation_output.confidence

    except Exception as e:
        logger.error(f"Image class label validation failed: {e}")
        validation_reasoning = "ERROR"
        confidence = 0

    return validation_reasoning, confidence

OpenAIForImageClassification

Bases: BaseImageClassification

OpenAIForImageClassification pipeline using OpenAI API.

Example usage:

from swiftannotate.image import OpenAIForImageClassification

# Initialize the pipeline
classification_pipeline = OpenAIForImageClassification(
    classification_model="gpt-4o",
    validation_model="gpt-4o-mini",
    api_key="your_api_key_here",
    classification_labels=["kitchen", "bedroom", "living room"],
    output_file="captions.json"
)

# Generate captions for a list of images
image_paths = ["path/to/image1.jpg"]
results = classification_pipeline.generate(image_paths)

# Print results
# Output: [
#     {
#         "image_path": 'path/to/image1.jpg', 
#         "image_classification": 'kitchen', 
#         "validation_reasoning": 'The class label is valid.', 
#         "validation_score": 0.6
#     }, 
# ]
Source code in swiftannotate/image/classification/openai.py
class OpenAIForImageClassification(BaseImageClassification):
    """
    OpenAIForImageClassification pipeline using OpenAI API.

    Example usage:

    ```python
    from swiftannotate.image import OpenAIForImageClassification

    # Initialize the pipeline
    classification_pipeline = OpenAIForImageClassification(
        classification_model="gpt-4o",
        validation_model="gpt-4o-mini",
        api_key="your_api_key_here",
        classification_labels=["kitchen", "bedroom", "living room"],
        output_file="captions.json"
    )

    # Generate captions for a list of images
    image_paths = ["path/to/image1.jpg"]
    results = classification_pipeline.generate(image_paths)

    # Print results
    # Output: [
    #     {
    #         "image_path": 'path/to/image1.jpg', 
    #         "image_classification": 'kitchen', 
    #         "validation_reasoning": 'The class label is valid.', 
    #         "validation_score": 0.6
    #     }, 
    # ]
    ```
    """
    def __init__(
        self, 
        classification_model: str, 
        validation_model: str,
        api_key: str, 
        classification_labels: List[str],
        classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
        **kwargs
    ):
        """
        Initializes the OpenAIForImageClassification pipeline.

        Args:
            classification_model (str): 
                Can be either "gpt-4o", "gpt-4o-mini", etc. or 
                specific versions of model supported by OpenAI.
            validation_model (str): 
                Can be either "gpt-4o", "gpt-4o-mini", etc. or 
                specific versions of model supported by OpenAI.
            api_key (str): OpenAI API key.
            classification_labels (List[str]):
                List of classification labels to be used for the image classification.
            classification_prompt (str | None, optional): 
                System prompt for classifying images.
                Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
            validation (bool, optional): 
                Whether to run the validation step. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image class labels; it should specify the range of the validation score to be generated. 
                Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt. 
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image class labels. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Keyword Arguments:
            detail (str, optional): 
                Specific to OpenAI. Detail level of the image (Higher resolution costs more). Defaults to "low".

        Notes:
            `validation_prompt` should specify the rules for validating the class label and the range of validation score to be generated example (0-1).
            Your `validation_threshold` should be within this specified range.

            It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
            You can also add Few-shot learning examples to the prompt to help the model understand the context of the class labels.
        """
        self.classification_model = classification_model
        self.validation_model = validation_model
        self.client = OpenAI(api_key=api_key)

        super().__init__(
            classification_labels=classification_labels,
            classification_prompt=classification_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

        self.detail = kwargs.get("detail", "low")

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
        """
        Annotates the image with a class label. Implements the logic to generate class labels for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from 
        the previous iteration in case the class label does not pass the validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            str: Generated class label for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the class label you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Regenerate the class label for the given image.
                Classify the given image as {', '.join(map(str, self.classification_labels))}
            """
        else:
            user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))}"

        messages=[
            {"role": "system", "content": self.classification_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}",
                            "detail": self.detail
                        },
                    },
                    {"type": "text", "text": user_prompt},
                ]
            }
        ]

        try:
            response = self.client.chat.completions.create(
                model=self.classification_model,
                messages=messages,
                response_format=ImageClassificationOutputOpenAI,
                **kwargs
            )
            output = response.choices[0].message.parsed
            class_label = output.class_label.lower()
        except Exception as e:
            logger.error(f"Image classification failed: {e}")
            class_label = "ERROR"

        return class_label

    def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]: 
        """
        Validates the class label generated for the image.

        Args:
            image (str): Base64 encoded image.
            class_label (str): Class Label generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the class label.
        """
        if class_label == "ERROR":
            return "ERROR", 0

        messages = [
            {
                "role": "system",
                "content": self.validation_prompt
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image}",
                            "detail": self.detail
                        },
                    },
                    {
                        "type": "text",
                        "text": class_label + "\nValidate the class label generated for the given image."
                    }
                ]
            }
        ]      

        try:
            response = self.client.chat.completions.create(
                model=self.validation_model,
                messages=messages,
                response_format=ImageValidationOutputOpenAI,
                **kwargs
            )
            validation_output = response.choices[0].message.parsed
            validation_reasoning = validation_output.validation_reasoning
            confidence = validation_output.confidence

        except Exception as e:
            logger.error(f"Image class label validation failed: {e}")
            validation_reasoning = "ERROR"
            confidence = 0

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates class label for a list of images. 

        Args:
            image_paths (List[str]): List of image paths to generate class labels for.
            **kwargs: Additional arguments to pass to the method for custom pipeline interactions. To control generation parameters for the model.

        Returns:
            List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(classification_model, validation_model, api_key, classification_labels, classification_prompt=BASE_IMAGE_CLASSIFICATION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None, **kwargs)

Initializes the OpenAIForImageClassification pipeline.

Parameters:

    classification_model (str, required):
        Can be "gpt-4o", "gpt-4o-mini", etc., or a specific model version supported by OpenAI.
    validation_model (str, required):
        Can be "gpt-4o", "gpt-4o-mini", etc., or a specific model version supported by OpenAI.
    api_key (str, required):
        OpenAI API key.
    classification_labels (List[str], required):
        List of classification labels to be used for the image classification.
    classification_prompt (str | None, optional; default BASE_IMAGE_CLASSIFICATION_PROMPT):
        System prompt for classifying images.
    validation (bool, optional; default True):
        Whether to run the validation step.
    validation_prompt (str | None, optional; default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT):
        System prompt for validating image class labels; it should specify the range of the validation score to be generated.
    validation_threshold (float, optional; default 0.5):
        Threshold for deciding whether a class label is valid; it must lie within the score range specified in the validation prompt.
    max_retry (int, optional; default 3):
        Number of retries before giving up on the image class label.
    output_file (str | None, optional; default None):
        Output file path; only JSON is supported for now.

Other Parameters:

    detail (str, optional; default "low"):
        Specific to OpenAI. Detail level of the image (higher resolution costs more).

Notes

validation_prompt should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1). Your validation_threshold should fall within this specified range.

It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels. You can also add few-shot examples to the prompts for the same purpose.
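The detail keyword argument (see Other Parameters above) controls the OpenAI image detail level and therefore the token cost. A short sketch of constructing the pipeline with high detail:

```python
# Sketch: requesting high-detail image processing. The `detail` kwarg is stored in
# __init__ (see below) and attached to every image_url payload sent to OpenAI.
classification_pipeline = OpenAIForImageClassification(
    classification_model="gpt-4o",
    validation_model="gpt-4o-mini",
    api_key="your_api_key_here",
    classification_labels=["kitchen", "bedroom", "living room"],
    output_file="classifications.json",
    detail="high",  # defaults to "low"; higher detail costs more tokens
)
```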

Source code in swiftannotate/image/classification/openai.py
def __init__(
    self, 
    classification_model: str, 
    validation_model: str,
    api_key: str, 
    classification_labels: List[str],
    classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
    **kwargs
):
    """
    Initializes the OpenAIForImageClassification pipeline.

    Args:
        classification_model (str): 
            Can be either "gpt-4o", "gpt-4o-mini", etc. or 
            specific versions of model supported by OpenAI.
        validation_model (str): 
            Can be either "gpt-4o", "gpt-4o-mini", etc. or 
            specific versions of model supported by OpenAI.
        api_key (str): OpenAI API key.
        classification_labels (List[str]):
            List of classification labels to be used for the image classification.
        classification_prompt (str | None, optional): 
            System prompt for classifying images.
            Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
        validation (bool, optional): 
            Use validation step or not. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image class labels; it should specify the range of the validation score to be generated.
            Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold that determines whether an image class label is valid; it should lie within the range specified for the validation score.
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image class labels. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Keyword Arguments:
        detail (str, optional): 
            Specific to OpenAI. Detail level of the image (Higher resolution costs more). Defaults to "low".

    Notes:
        `validation_prompt` should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1).
        Your `validation_threshold` should be within this specified range.

        It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
        You can also add few-shot learning examples to the prompt for the same purpose.
    """
    self.classification_model = classification_model
    self.validation_model = validation_model
    self.client = OpenAI(api_key=api_key)

    super().__init__(
        classification_labels=classification_labels,
        classification_prompt=classification_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

    self.detail = kwargs.get("detail", "low")
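
Following the notes above, a hedged sketch of how class descriptions and a matching threshold might be supplied; the prompt wording, labels, and threshold here are illustrative, not library defaults.

```python
# Prompts that describe each class; the validation prompt pins the score range to 0-1
# so that validation_threshold=0.7 lies inside it.
classification_prompt = (
    "You are labelling product photos.\n"
    "- kitchen: the image shows a kitchen scene\n"
    "- bottle: the image is dominated by a bottle\n"
    "- none: neither of the above"
)
validation_prompt = (
    "Check whether the proposed class label matches the image and explain why. "
    "Return a confidence score between 0 and 1."
)

classifier = OpenAIForImageClassification(
    classification_model="gpt-4o",
    validation_model="gpt-4o-mini",
    api_key="your_openai_api_key",
    classification_labels=["kitchen", "bottle", "none"],
    classification_prompt=classification_prompt,
    validation_prompt=validation_prompt,
    validation_threshold=0.7,
    detail="high",  # OpenAI-specific kwarg read in __init__ above
)
```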

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a class label. Implements the logic to generate class labels for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration when the class label does not pass the validation threshold.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | str | Base64 encoded image. | required |
| feedback_prompt | str | Feedback prompt used to generate a better class label. Defaults to ''. | '' |
| **kwargs | | Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model. | {} |

Returns:

| Type | Description |
| --- | --- |
| str | Generated class label for the image. |

Source code in swiftannotate/image/classification/openai.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:        
    """
    Annotates the image with a class label. Implements the logic to generate class labels for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from
    the previous iteration when the class label does not pass the validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
        **kwargs: Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model.

    Returns:
        str: Generated class label for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the class label you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Regenerate the class label for the given image.
            Classify the given image as {', '.join(map(str, self.classification_labels))}
        """
    else:
        user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))}"

    messages=[
        {"role": "system", "content": self.classification_prompt},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image}",
                        "detail": self.detail
                    },
                },
                {"type": "text", "text": user_prompt},
            ]
        }
    ]

    try:
        response = self.client.beta.chat.completions.parse(  # parse() populates message.parsed for structured outputs
            model=self.classification_model,
            messages=messages,
            response_format=ImageClassificationOutputOpenAI,
            **kwargs
        )
        output = response.choices[0].message.parsed
        class_label = output.class_label.lower()
    except Exception as e:
        logger.error(f"Image classification failed: {e}")
        class_label = "ERROR"

    return class_label
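
annotate() is normally driven by generate(), but a small sketch of calling it directly clarifies the expected input: a base64-encoded image string. The classifier instance and file path are placeholders from the sketch above.

```python
import base64

with open("path/to/image1.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# Extra kwargs are forwarded to the OpenAI completion call.
label = classifier.annotate(image_b64, temperature=0)
print(label)  # e.g. "kitchen", or "ERROR" if the request failed
```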

generate(image_paths, **kwargs)

Generates class labels for a list of images.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_paths | List[str] | List of image paths to generate class labels for. | required |
| **kwargs | | Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model. | {} |

Returns:

| Type | Description |
| --- | --- |
| List[Dict] | List of class labels, validation reasoning and confidence scores for each image. |

Source code in swiftannotate/image/classification/openai.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates class label for a list of images. 

    Args:
        image_paths (List[str]): List of image paths to generate class labels for.
        **kwargs: Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model.

    Returns:
        List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results
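
A hedged sketch of calling generate() with generation kwargs; the result fields shown mirror the example output documented for the Qwen2-VL pipeline below and may differ slightly in practice.

```python
results = classifier.generate(
    ["path/to/image1.jpg", "path/to/image2.jpg"],
    temperature=0,  # forwarded to the OpenAI calls
)

# Expected shape of each entry (values are examples):
# {
#     "image_path": "path/to/image1.jpg",
#     "image_classification": "kitchen",
#     "validation_reasoning": "The class label is valid.",
#     "validation_score": 0.6
# }
```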

validate(image, class_label, **kwargs)

Validates the class label generated for the image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | str | Base64 encoded image. | required |
| class_label | str | Class label generated for the image. | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[str, float] | Validation reasoning and confidence score for the class label. |

Source code in swiftannotate/image/classification/openai.py
def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]: 
    """
    Validates the class label generated for the image.

    Args:
        image (str): Base64 encoded image.
        class_label (str): Class Label generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the class label.
    """
    if class_label == "ERROR":
        return "ERROR", 0

    messages = [
        {
            "role": "system",
            "content": self.validation_prompt
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image}",
                        "detail": self.detail
                    },
                },
                {
                    "type": "text",
                    "text": class_label + "\nValidate the class label generated for the given image."
                }
            ]
        }
    ]      

    try:
        response = self.client.beta.chat.completions.parse(  # parse() populates message.parsed for structured outputs
            model=self.validation_model,
            messages=messages,
            response_format=ImageValidationOutputOpenAI,
            **kwargs
        )
        validation_output = response.choices[0].message.parsed
        validation_reasoning = validation_output.validation_reasoning
        confidence = validation_output.confidence

    except Exception as e:
        logger.error(f"Image class label validation failed: {e}")
        validation_reasoning = "ERROR"
        confidence = 0

    return validation_reasoning, confidence
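
A short sketch of how the returned tuple is typically interpreted, assuming the base class stores validation_threshold as an attribute; image_b64 is the base64 string from the annotate() sketch above.

```python
reasoning, confidence = classifier.validate(image_b64, "kitchen")

if confidence >= classifier.validation_threshold:
    print(f"Label accepted: {reasoning}")
else:
    # In the pipeline, this reasoning becomes the feedback_prompt for the next annotate() retry.
    print(f"Label rejected, feedback for retry: {reasoning}")
```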

Qwen2VLForImageClassification

Bases: BaseImageClassification

Qwen2VLForImageClassification pipeline using Qwen2VL model.

Example usage:

from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import BitsAndBytesConfig
from swiftannotate.image import Qwen2VLForImageClassification

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True
)

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    device_map="auto",
    torch_dtype="auto",
    quantization_config=quantization_config)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Initialize the classification pipeline
kwargs = {"temperature": 0}
classification_pipeline = Qwen2VLForImageClassification(
    model=model,
    processor=processor,
    classification_labels=["kitchen", "bottle", "none"],
    output_file="output.json",
)

# Generate class labels for images
image_paths = ['path/to/image1.jpg']
results = classification_pipeline.generate(image_paths, **kwargs)

# Print results
# Output: [
#     {
#         "image_path": 'path/to/image1.jpg', 
#         "image_classification": 'kitchen', 
#         "validation_reasoning": 'The class label is valid.', 
#         "validation_score": 0.6
#     }, 
# ]

Source code in swiftannotate/image/classification/qwen.py
class Qwen2VLForImageClassification(BaseImageClassification):
    """
    Qwen2VLForImageClassification pipeline using Qwen2VL model.

    Example usage:
    ```python
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from transformers import BitsAndBytesConfig
    from swiftannotate.image import Qwen2VLForImageClassification

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="float16",
        bnb_4bit_use_double_quant=True
    )

    model = AutoModelForImageTextToText.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        device_map="auto",
        torch_dtype="auto",
        quantization_config=quantization_config)

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    # Initialize the classification pipeline
    kwargs = {"temperature": 0}
    classification_pipeline = Qwen2VLForImageClassification(
        model=model,
        processor=processor,
        classification_labels=["kitchen", "bottle", "none"],
        output_file="output.json",
    )

    # Generate class labels for images
    image_paths = ['path/to/image1.jpg']
    results = classification_pipeline.generate(image_paths, **kwargs)

    # Print results
    # Output: [
    #     {
    #         "image_path": 'path/to/image1.jpg', 
    #         "image_classification": 'kitchen', 
    #         "validation_reasoning": 'The class label is valid.', 
    #         "validation_score": 0.6
    #     }, 
    # ]
    ```
    """
    def __init__(
        self, 
        model: AutoModelForImageTextToText | Qwen2VLForConditionalGeneration, 
        processor: AutoProcessor | Qwen2VLProcessor,
        classification_labels: List[str],
        classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
        validation: bool = True,
        validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
        validation_threshold: float = 0.5,
        max_retry: int = 3, 
        output_file: str | None = None,
        **kwargs
    ):     
        """
        Initializes the Qwen2VLForImageClassification pipeline.

        Args:
            model (AutoModelForImageTextToText): 
                Model for image classification. Should be an instance of AutoModelForImageTextToText with Qwen2-VL pretrained weights.
                Can be any version of Qwen2-VL model (7B, 72B).
            processor (AutoProcessor): 
                Processor for the Qwen2-VL model. Should be an instance of AutoProcessor.
            classification_labels (List[str]):
                List of classification labels to be used for the image classification.
            classification_prompt (str | None, optional): 
                System prompt for classifying images.
                Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
            validation (bool, optional): 
                Use validation step or not. Defaults to True.
            validation_prompt (str | None, optional): 
                System prompt for validating image class labels; it should specify the range of the validation score to be generated.
                Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
            validation_threshold (float, optional): 
                Threshold that determines whether an image class label is valid; it should lie within the range specified for the validation score.
                Defaults to 0.5.
            max_retry (int, optional):
                Number of retries before giving up on the image class labels. 
                Defaults to 3.
            output_file (str | None, optional): 
                Output file path, only JSON is supported for now. 
                Defaults to None.

        Keyword Arguments:
            resize_height (int, optional):
                Height to resize the image before generating class labels. Defaults to 280.
            resize_width (int, optional):
                Width to resize the image before generating class labels. Defaults to 420.

        Notes:
            `validation_prompt` should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1).
            Your `validation_threshold` should be within this specified range.

            It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
            You can also add few-shot learning examples to the prompt for the same purpose.
        """    

        if not isinstance(model, Qwen2VLForConditionalGeneration):
            raise ValueError("Model should be an instance of Qwen2VLForConditionalGeneration.")
        if not isinstance(processor, Qwen2VLProcessor):
            raise ValueError("Processor should be an instance of Qwen2VLProcessor.")

        self.model = model
        self.processor = processor

        super().__init__(
            classification_labels=classification_labels,
            classification_prompt=classification_prompt,
            validation=validation,
            validation_prompt=validation_prompt,
            validation_threshold=validation_threshold,
            max_retry=max_retry,
            output_file=output_file
        )

        self.resize_height = kwargs.get("resize_height", 280)
        self.resize_width = kwargs.get("resize_width", 420)

    def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
        """
        Annotates the image with a class label. Implements the logic to generate class labels for an image.

        **Note**: The feedback_prompt is dynamically updated using the validation reasoning from
        the previous iteration when the class label does not pass the validation threshold.

        Args:
            image (str): Base64 encoded image.
            feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
            **kwargs: Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model.

        Returns:
            str: Generated class label for the image.
        """
        if feedback_prompt:
            user_prompt = f"""
                Last time the class label you generated for this image was incorrect because of the following reasons:
                {feedback_prompt}

                Regenerate the class label for the given image.
                Classify the given image as {', '.join(map(str, self.classification_labels))}
                Return output as a JSON object with key as 'class_label'
            """
        else:
            user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))} \nReturn output as a JSON object with key as 'class_label'"

        messages = [
            {"role": "system", "content": self.classification_prompt},
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image", 
                        "image": f"data:image;base64,{image}",
                        "resized_height": self.resize_height,
                        "resized_width": self.resize_width,
                    },
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Inference: Generation of the output
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 512

        generated_ids = self.model.generate(**inputs, **kwargs)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        class_label = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        try:
            class_label = class_label.replace('```', '').replace('json', '')
            class_label = json.loads(class_label)
            class_label = class_label["class_label"].lower()
        except Exception as e:
            logger.error(f"Image classification parsing failed trying to parse using another logic.")
            potential_class_labels = [label.lower() for label in class_label.split() if label in self.classification_labels]
            class_label = potential_class_labels[0] if potential_class_labels else "ERROR"

        return class_label

    def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]:
        """
        Validates the class label generated for the image.

        Args:
            image (str): Base64 encoded image.
            class_label (str): Class Label generated for the image.

        Returns:
            Tuple[str, float]: Validation reasoning and confidence score for the class label.
        """
        messages = [
            {"role": "system", "content": self.validation_prompt},
            {
                "role": "user", 
                "content": [
                    {
                        "type": "image", 
                        "image": f"data:image;base64,{image}",
                        "resized_height": self.resize_height,
                        "resized_width": self.resize_width,
                    },
                    {"type": "text", "text": class_label},
                    {
                        "type": "text", 
                        "text": """
                        Validate the class label generated for the given image. 
                        Return output as a JSON object with keys as 'validation_reasoning' and 'confidence'.
                        """
                    },
                ],
            },
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Inference: Generation of the output
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 512

        generated_ids = self.model.generate(**inputs, **kwargs)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        validation_output = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # TODO: Need a better way to parse the output
        try:
            validation_output = validation_output.replace('```', '').replace('json', '')
            validation_output = json.loads(validation_output)
            validation_reasoning = validation_output["validation_reasoning"]
            confidence = validation_output["confidence"]
        except Exception as e:
            logger.error(f"Image class label validation parsing failed trying to parse using another logic.")

            number_str  = ''.join((ch if ch in '0123456789.-e' else ' ') for ch in validation_output)
            number_str = [i for i in number_str.split() if i.isalnum()]
            potential_confidence_scores = [float(i) for i in number_str if float(i) >= 0 and float(i) <= 1]
            confidence = max(potential_confidence_scores) if potential_confidence_scores else 0.0
            validation_reasoning = validation_output

        return validation_reasoning, confidence

    def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
        """
        Generates class label for a list of images. 

        Args:
            image_paths (List[str]): List of image paths to generate class labels for.
            **kwargs: Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model.

        Returns:
            List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
        """
        results = super().generate(
            image_paths=image_paths, 
            **kwargs
        )

        return results

__init__(model, processor, classification_labels, classification_prompt=BASE_IMAGE_CLASSIFICATION_PROMPT, validation=True, validation_prompt=BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT, validation_threshold=0.5, max_retry=3, output_file=None, **kwargs)

Initializes the Qwen2VLForImageClassification pipeline.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | AutoModelForImageTextToText | Model for image classification. Should be an instance of AutoModelForImageTextToText with Qwen2-VL pretrained weights. Can be any version of the Qwen2-VL model (7B, 72B). | required |
| processor | AutoProcessor | Processor for the Qwen2-VL model. Should be an instance of AutoProcessor. | required |
| classification_labels | List[str] | List of classification labels to use for the image classification. | required |
| classification_prompt | str or None | System prompt for classifying images. Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided. | BASE_IMAGE_CLASSIFICATION_PROMPT |
| validation | bool | Whether to use the validation step. Defaults to True. | True |
| validation_prompt | str or None | System prompt for validating image class labels; it should specify the range of the validation score to be generated. Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided. | BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT |
| validation_threshold | float | Threshold that determines whether an image class label is valid; it should lie within the range specified for the validation score. Defaults to 0.5. | 0.5 |
| max_retry | int | Number of retries before giving up on the image class label. Defaults to 3. | 3 |
| output_file | str or None | Output file path; only JSON is supported for now. Defaults to None. | None |

Other Parameters:

| Name | Type | Description |
| --- | --- | --- |
| resize_height | int | Height to resize the image to before generating class labels. Defaults to 280. |
| resize_width | int | Width to resize the image to before generating class labels. Defaults to 420. |

Notes

validation_prompt should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1). Your validation_threshold should be within this specified range.

It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels. You can also add few-shot learning examples to the prompt for the same purpose.

Source code in swiftannotate/image/classification/qwen.py
def __init__(
    self, 
    model: AutoModelForImageTextToText | Qwen2VLForConditionalGeneration, 
    processor: AutoProcessor | Qwen2VLProcessor,
    classification_labels: List[str],
    classification_prompt: str = BASE_IMAGE_CLASSIFICATION_PROMPT, 
    validation: bool = True,
    validation_prompt: str = BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT,
    validation_threshold: float = 0.5,
    max_retry: int = 3, 
    output_file: str | None = None,
    **kwargs
):     
    """
    Initializes the Qwen2VLForImageClassification pipeline.

    Args:
        model (AutoModelForImageTextToText): 
            Model for image classification. Should be an instance of AutoModelForImageTextToText with Qwen2-VL pretrained weights.
            Can be any version of Qwen2-VL model (7B, 72B).
        processor (AutoProcessor): 
            Processor for the Qwen2-VL model. Should be an instance of AutoProcessor.
        classification_labels (List[str]):
            List of classification labels to be used for the image classification.
        classification_prompt (str | None, optional): 
            System prompt for classifying images.
            Uses the default BASE_IMAGE_CLASSIFICATION_PROMPT if not provided.
        validation (bool, optional): 
            Use validation step or not. Defaults to True.
        validation_prompt (str | None, optional): 
            System prompt for validating image class labels; it should specify the range of the validation score to be generated.
            Uses the default BASE_IMAGE_CLASSIFICATION_VALIDATION_PROMPT if not provided.
        validation_threshold (float, optional): 
            Threshold that determines whether an image class label is valid; it should lie within the range specified for the validation score.
            Defaults to 0.5.
        max_retry (int, optional):
            Number of retries before giving up on the image class labels. 
            Defaults to 3.
        output_file (str | None, optional): 
            Output file path, only JSON is supported for now. 
            Defaults to None.

    Keyword Arguments:
        resize_height (int, optional):
            Height to resize the image before generating class labels. Defaults to 280.
        resize_width (int, optional):
            Width to resize the image before generating class labels. Defaults to 420.

    Notes:
        `validation_prompt` should specify the rules for validating the class label and the range of the validation score to be generated, for example (0-1).
        Your `validation_threshold` should be within this specified range.

        It is advised to include class descriptions in the classification_prompt and validation_prompt to help the model understand the context of the class labels.
        You can also add few-shot learning examples to the prompt for the same purpose.
    """    

    if not isinstance(model, Qwen2VLForConditionalGeneration):
        raise ValueError("Model should be an instance of Qwen2VLForConditionalGeneration.")
    if not isinstance(processor, Qwen2VLProcessor):
        raise ValueError("Processor should be an instance of Qwen2VLProcessor.")

    self.model = model
    self.processor = processor

    super().__init__(
        classification_labels=classification_labels,
        classification_prompt=classification_prompt,
        validation=validation,
        validation_prompt=validation_prompt,
        validation_threshold=validation_threshold,
        max_retry=max_retry,
        output_file=output_file
    )

    self.resize_height = kwargs.get("resize_height", 280)
    self.resize_width = kwargs.get("resize_width", 420)
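
A hedged sketch of passing the resize keyword arguments read in __init__ above; model and processor are the Qwen2-VL objects from the example usage earlier, and the numeric values here are illustrative.

```python
classification_pipeline = Qwen2VLForImageClassification(
    model=model,
    processor=processor,
    classification_labels=["kitchen", "bottle", "none"],
    validation_threshold=0.7,
    resize_height=420,   # larger inputs keep more detail but cost more tokens
    resize_width=640,
    output_file="output.json",
)
```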

annotate(image, feedback_prompt='', **kwargs)

Annotates the image with a class label. Implements the logic to generate class labels for an image.

Note: The feedback_prompt is dynamically updated using the validation reasoning from the previous iteration when the class label does not pass the validation threshold.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | str | Base64 encoded image. | required |
| feedback_prompt | str | Feedback prompt used to generate a better class label. Defaults to ''. | '' |
| **kwargs | | Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model. | {} |

Returns:

| Type | Description |
| --- | --- |
| str | Generated class label for the image. |

Source code in swiftannotate/image/classification/qwen.py
def annotate(self, image: str, feedback_prompt:str = "", **kwargs) -> str:
    """
    Annotates the image with a class label. Implements the logic to generate class labels for an image.

    **Note**: The feedback_prompt is dynamically updated using the validation reasoning from
    the previous iteration when the class label does not pass the validation threshold.

    Args:
        image (str): Base64 encoded image.
        feedback_prompt (str, optional): Feedback prompt for the user to generate a better class label. Defaults to ''.
        **kwargs: Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model.

    Returns:
        str: Generated class label for the image.
    """
    if feedback_prompt:
        user_prompt = f"""
            Last time the class label you generated for this image was incorrect because of the following reasons:
            {feedback_prompt}

            Regenerate the class label for the given image.
            Classify the given image as {', '.join(map(str, self.classification_labels))}
            Return output as a JSON object with key as 'class_label'
        """
    else:
        user_prompt = f"Classify the given image as {', '.join(map(str, self.classification_labels))} \nReturn output as a JSON object with key as 'class_label'"

    messages = [
        {"role": "system", "content": self.classification_prompt},
        {
            "role": "user", 
            "content": [
                {
                    "type": "image", 
                    "image": f"data:image;base64,{image}",
                    "resized_height": self.resize_height,
                    "resized_width": self.resize_width,
                },
                {"type": "text", "text": user_prompt},
            ],
        },
    ]

    text = self.processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = self.processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(self.model.device)

    # Inference: Generation of the output
    if "max_new_tokens" not in kwargs:
        kwargs["max_new_tokens"] = 512

    generated_ids = self.model.generate(**inputs, **kwargs)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    class_label = self.processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    try:
        class_label = class_label.replace('```', '').replace('json', '')
        class_label = json.loads(class_label)
        class_label = class_label["class_label"].lower()
    except Exception as e:
        logger.error(f"Image classification parsing failed trying to parse using another logic.")
        potential_class_labels = [label.lower() for label in class_label.split() if label in self.classification_labels]
        class_label = potential_class_labels[0] if potential_class_labels else "ERROR"

    return class_label
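
Any extra kwargs are forwarded to model.generate(); a small sketch with standard Hugging Face generation arguments (image_b64 is a base64-encoded image string, as in the OpenAI annotate() sketch earlier on this page).

```python
label = classification_pipeline.annotate(
    image_b64,
    max_new_tokens=128,  # overrides the default of 512 set above
    do_sample=False,
)
```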

generate(image_paths, **kwargs)

Generates class labels for a list of images.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_paths | List[str] | List of image paths to generate class labels for. | required |
| **kwargs | | Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model. | {} |

Returns:

| Type | Description |
| --- | --- |
| List[Dict] | List of class labels, validation reasoning and confidence scores for each image. |

Source code in swiftannotate/image/classification/qwen.py
def generate(self, image_paths: List[str], **kwargs) -> List[Dict]:
    """
    Generates class label for a list of images. 

    Args:
        image_paths (List[str]): List of image paths to generate class labels for.
        **kwargs: Additional arguments passed to the method for custom pipeline interactions, such as generation parameters for the model.

    Returns:
        List[Dict]: List of class labels, validation reasoning and confidence scores for each image.
    """
    results = super().generate(
        image_paths=image_paths, 
        **kwargs
    )

    return results

validate(image, class_label, **kwargs)

Validates the class label generated for the image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | str | Base64 encoded image. | required |
| class_label | str | Class label generated for the image. | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[str, float] | Validation reasoning and confidence score for the class label. |

Source code in swiftannotate/image/classification/qwen.py
def validate(self, image: str, class_label: str, **kwargs) -> Tuple[str, float]:
    """
    Validates the class label generated for the image.

    Args:
        image (str): Base64 encoded image.
        class_label (str): Class Label generated for the image.

    Returns:
        Tuple[str, float]: Validation reasoning and confidence score for the class label.
    """
    messages = [
        {"role": "system", "content": self.validation_prompt},
        {
            "role": "user", 
            "content": [
                {
                    "type": "image", 
                    "image": f"data:image;base64,{image}",
                    "resized_height": self.resize_height,
                    "resized_width": self.resize_width,
                },
                {"type": "text", "text": class_label},
                {
                    "type": "text", 
                    "text": """
                    Validate the class label generated for the given image. 
                    Return output as a JSON object with keys as 'validation_reasoning' and 'confidence'.
                    """
                },
            ],
        },
    ]

    text = self.processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = self.processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(self.model.device)

    # Inference: Generation of the output
    if "max_new_tokens" not in kwargs:
        kwargs["max_new_tokens"] = 512

    generated_ids = self.model.generate(**inputs, **kwargs)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    validation_output = self.processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # TODO: Need a better way to parse the output
    try:
        validation_output = validation_output.replace('```', '').replace('json', '')
        validation_output = json.loads(validation_output)
        validation_reasoning = validation_output["validation_reasoning"]
        confidence = validation_output["confidence"]
    except Exception as e:
        logger.error(f"Image class label validation parsing failed trying to parse using another logic.")

        number_str  = ''.join((ch if ch in '0123456789.-e' else ' ') for ch in validation_output)
        number_str = [i for i in number_str.split() if i.isalnum()]
        potential_confidence_scores = [float(i) for i in number_str if float(i) >= 0 and float(i) <= 1]
        confidence = max(potential_confidence_scores) if potential_confidence_scores else 0.0
        validation_reasoning = validation_output

    return validation_reasoning, confidence

Object Detection

OwlV2ForObjectDetection

Bases: BaseObjectDetection

Source code in swiftannotate/image/object_detection/owlv2.py
class OwlV2ForObjectDetection(BaseObjectDetection):
    def __init__(
        self,
        model: Owlv2ForObjectDetection,
        processor: Owlv2Processor,
        class_labels: List[str],
        confidence_threshold: float = 0.5,
        validation: bool = False,
        validation_prompt: str | None = None,
        validation_threshold: float | None = None,
        output_file: str | None = None
    ):
        """
        Initialize the OwlV2ObjectDetection class.

        Args:
            model (Owlv2ForObjectDetection):
                OwlV2 Object Detection model from Transformers.
            processor (Owlv2Processor): 
                OwlV2 Processor for Object Detection.
            class_labels (List[str]): 
                List of class labels.
            confidence_threshold (float, optional): 
                Minimum confidence threshold for object detection. 
                Defaults to 0.5.
            validation (bool, optional): 
                Whether to validate annotations from OwlV2. 
                Defaults to False.
            validation_prompt (str | None, optional): 
                Prompt to validate annotations. 
                Defaults to None.
            validation_threshold (float | None, optional): 
                Threshold score for annotation validation. 
                Defaults to None.
            output_file (str | None, optional): 
                Path to save results.
                If None, results are not saved. Defaults to None.

        Raises:
            ValueError: If model is not an instance of Owlv2ForObjectDetection.
            ValueError: If processor is not an instance of Owlv2Processor.
        """
        if not isinstance(model, Owlv2ForObjectDetection):
            raise ValueError("Model must be an instance of Owlv2ForObjectDetection")
        if not isinstance(processor, Owlv2Processor):
            raise ValueError("Processor must be an instance of Owlv2Processor")

        self.processor = processor
        self.model = model
        self.model.eval()

        super().__init__(
            class_labels,
            confidence_threshold,
            validation,
            validation_prompt,
            validation_threshold,
            output_file
        )

    def annotate(self, image: Image.Image) -> List[dict]:
        """
        Annotate an image with object detection labels

        Args:
            image (Image.Image): Image to be annotated.

        Returns:
            List[dict]: List of dictionaries containing the confidence scores, bounding box coordinates and class labels.
        """
        inputs = self.processor(text=self.class_labels, images=image, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        target_sizes = torch.Tensor([image.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs=outputs, 
            target_sizes=target_sizes, 
            threshold=self.confidence_threshold
        )
        return [{k: v.cpu().tolist() for k, v in prediction.items()} for prediction in results]

    def validate(self, image: Image.Image, annotations: List[dict]) -> Tuple:
        """
        Validate the annotations for an image with object detection labels.

        Currently, there is no validation method available for Object Detection.

        # TODO: Idea is to do some sort of object extraction using annotations and ask VLM to validate the extracted objects.
        # TODO: Need to figure out a way to use the VLM output for improving annotations.

        Args:
            image (Image.Image): Image to be validated.
            annotations (List[dict]): List of dictionaries containing the confidence scores, bounding box coordinates and class labels.

        Raises:
            NotImplementedError: Always raised; no validation method is available for object detection yet.
        """
        raise NotImplementedError("No validation method available for Object Detection yet")

    def generate(self, image_paths: List[str]) -> List[dict]:
        """
        Generate annotations for a list of image paths.

        Args:
            image_paths (List[str]): List of image paths.

        Returns:
            List[dict]: List of dictionaries containing the confidence scores, bounding box coordinates and class labels.
        """
        results = super().generate(
            image_paths
        )

        return results
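
This class has no example usage section above, so here is a hedged end-to-end sketch; the checkpoint name, labels, and paths are placeholders, and the pipeline import path follows the pattern of the other swiftannotate.image classes (note the different capitalization of the transformers model class).

```python
from transformers import Owlv2ForObjectDetection, Owlv2Processor
from swiftannotate.image import OwlV2ForObjectDetection as OwlV2Pipeline

checkpoint = "google/owlv2-base-patch16-ensemble"
model = Owlv2ForObjectDetection.from_pretrained(checkpoint)
processor = Owlv2Processor.from_pretrained(checkpoint)

detector = OwlV2Pipeline(
    model=model,
    processor=processor,
    class_labels=["a photo of a bottle", "a photo of a cat"],
    confidence_threshold=0.3,
    output_file="detections.json",
)

results = detector.generate(["path/to/image1.jpg"])
```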

__init__(model, processor, class_labels, confidence_threshold=0.5, validation=False, validation_prompt=None, validation_threshold=None, output_file=None)

Initialize the OwlV2ObjectDetection class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Owlv2ForObjectDetection | OwlV2 Object Detection model from Transformers. | required |
| processor | Owlv2Processor | OwlV2 Processor for Object Detection. | required |
| class_labels | List[str] | List of class labels. | required |
| confidence_threshold | float | Minimum confidence threshold for object detection. Defaults to 0.5. | 0.5 |
| validation | bool | Whether to validate annotations from OwlV2. Defaults to False. | False |
| validation_prompt | str or None | Prompt to validate annotations. Defaults to None. | None |
| validation_threshold | float or None | Threshold score for annotation validation. Defaults to None. | None |
| output_file | str or None | Path to save results. If None, results are not saved. Defaults to None. | None |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If model is not an instance of Owlv2ForObjectDetection. |
| ValueError | If processor is not an instance of Owlv2Processor. |

Source code in swiftannotate/image/object_detection/owlv2.py
def __init__(
    self,
    model: Owlv2ForObjectDetection,
    processor: Owlv2Processor,
    class_labels: List[str],
    confidence_threshold: float = 0.5,
    validation: bool = False,
    validation_prompt: str | None = None,
    validation_threshold: float | None = None,
    output_file: str | None = None
):
    """
    Initialize the OwlV2ObjectDetection class.

    Args:
        model (Owlv2ForObjectDetection):
            OwlV2 Object Detection model from Transformers.
        processor (Owlv2Processor): 
            OwlV2 Processor for Object Detection.
        class_labels (List[str]): 
            List of class labels.
        confidence_threshold (float, optional): 
            Minimum confidence threshold for object detection. 
            Defaults to 0.5.
        validation (bool, optional): 
            Whether to validate annotations from OwlV2. 
            Defaults to False.
        validation_prompt (str | None, optional): 
            Prompt to validate annotations. 
            Defaults to None.
        validation_threshold (float | None, optional): 
            Threshold score for annotation validation. 
            Defaults to None.
        output_file (str | None, optional): 
            Path to save results.
            If None, results are not saved. Defaults to None.

    Raises:
        ValueError: If model is not an instance of Owlv2ForObjectDetection.
        ValueError: If processor is not an instance of Owlv2Processor.
    """
    if not isinstance(model, Owlv2ForObjectDetection):
        raise ValueError("Model must be an instance of Owlv2ForObjectDetection")
    if not isinstance(processor, Owlv2Processor):
        raise ValueError("Processor must be an instance of Owlv2Processor")

    self.processor = processor
    self.model = model
    self.model.eval()

    super().__init__(
        class_labels,
        confidence_threshold,
        validation,
        validation_prompt,
        validation_threshold,
        output_file
    )

annotate(image)

Annotate an image with object detection labels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | Image.Image | Image to be annotated. | required |

Returns:

| Type | Description |
| --- | --- |
| List[dict] | List of dictionaries containing the confidence scores, bounding box coordinates and class labels. |

Source code in swiftannotate/image/object_detection/owlv2.py
def annotate(self, image: Image.Image) -> List[dict]:
    """
    Annotate an image with object detection labels

    Args:
        image (Image.Image): Image to be annotated.

    Returns:
        List[dict]: List of dictionaries containing the confidence scores, bounding box coordinates and class labels.
    """
    inputs = self.processor(text=self.class_labels, images=image, return_tensors="pt").to(self.model.device)

    with torch.no_grad():
        outputs = self.model(**inputs)

    target_sizes = torch.Tensor([image.size[::-1]])
    results = self.processor.post_process_object_detection(
        outputs=outputs, 
        target_sizes=target_sizes, 
        threshold=self.confidence_threshold
    )
    return [{k: v.cpu().tolist() for k, v in prediction.items()} for prediction in results]
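
A sketch of reading one prediction; the keys follow transformers' post_process_object_detection output ("scores", "labels" as indices into class_labels, and "boxes" as xmin, ymin, xmax, ymax), and detector refers to the instance from the sketch above.

```python
from PIL import Image

predictions = detector.annotate(Image.open("path/to/image1.jpg"))

for score, label_idx, box in zip(
    predictions[0]["scores"], predictions[0]["labels"], predictions[0]["boxes"]
):
    print(f"{detector.class_labels[label_idx]}: {score:.2f} at {box}")
```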

generate(image_paths)

Generate annotations for a list of image paths.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_paths | List[str] | List of image paths. | required |

Returns:

| Type | Description |
| --- | --- |
| List[dict] | List of dictionaries containing the confidence scores, bounding box coordinates and class labels. |

Source code in swiftannotate/image/object_detection/owlv2.py
def generate(self, image_paths: List[str]) -> List[dict]:
    """
    Generate annotations for a list of image paths.

    Args:
        image_paths (List[str]): List of image paths.

    Returns:
        List[dict]: List of dictionaries containing the confidence scores, bounding box coordinates and class labels.
    """
    results = super().generate(
        image_paths
    )

    return results

validate(image, annotations)

Validate the annotations for an image with object detection labels.

Currently, there is no validation method available for Object Detection.

TODO: Idea is to do some sort of object extraction using annotations and ask VLM to validate the extracted objects.

TODO: Need to figure out a way to use the VLM output for improving annotations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image | Image.Image | Image to be validated. | required |
| annotations | List[dict] | List of dictionaries containing the confidence scores, bounding box coordinates and class labels. | required |

Raises:

| Type | Description |
| --- | --- |
| NotImplementedError | Always raised; no validation method is available for object detection yet. |

Source code in swiftannotate/image/object_detection/owlv2.py
def validate(self, image: Image.Image, annotations: List[dict]) -> Tuple:
    """
    Validate the annotations for an image with object detection labels.

    Currently, there is no validation method available for Object Detection.

    # TODO: Idea is to do some sort of object extraction using annotations and ask VLM to validate the extracted objects.
    # TODO: Need to figure out a way to use the VLM output for improving annotations.

    Args:
        image (Image.Image): Image to be validated.
        annotations (List[dict]): List of dictionaries containing the confidence scores, bounding box coordinates and class labels.

    Raises:
        NotImplementedError: Always raised; no validation method is available for object detection yet.
    """
    raise NotImplementedError("No validation method available for Object Detection yet")