API Reference

BatchInferer

A class to manage batch inference jobs using AWS Bedrock.

This class handles the creation, monitoring, and retrieval of batch inference jobs for large-scale model invocations using the AWS Bedrock service.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_name` | `str` | The name/ID of the AWS Bedrock model to use | *required* |
| `bucket_name` | `str` | The S3 bucket name for storing input/output data | *required* |
| `region` | `str` | The region to run the batch inference job in | *required* |
| `job_name` | `str` | A unique name for the batch inference job | *required* |
| `role_arn` | `str` | The AWS IAM role ARN with the necessary permissions | *required* |
| `time_out_duration_hours` | `int` | Maximum job runtime in hours. Defaults to 24. | `24` |
| `session` | `boto3.Session` | A boto3 session to use for calls to AWS. If one is not provided, a new one will be created. | `None` |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `job_arn` | `str` | The ARN of the created batch inference job |
| `results` | `List[dict]` | The results of the batch inference job. Available after job completion. |
| `manifest` | `Manifest` | Job execution statistics. Available after job completion. |
| `job_status` | `str` | Current status of the batch job. One of `VALID_FINISHED_STATUSES`. |
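
A minimal end-to-end sketch of the intended workflow, assuming the package is importable as `llmbo` (the class lives in `src/llmbo/batch_inferer.py`); all AWS resource names below are placeholders, and `ModelInput` fields beyond `messages` are omitted for brevity:

```python
import os

# Assumed import path for illustration.
from llmbo import BatchInferer, ModelInput

os.environ.setdefault("AWS_PROFILE", "my-aws-profile")  # required by the constructor

bi = BatchInferer(
    model_name="anthropic.claude-3-haiku-20240307-v1:0",
    bucket_name="my-inference-bucket",  # must exist in the same region as the model
    region="eu-west-2",
    job_name="batch-job-2024-01-01",
    role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole",
)

# AWS Bedrock batch inference requires at least 100 records.
inputs = {
    f"{i:03d}": ModelInput(messages=[{"role": "user", "content": f"Question {i}"}])
    for i in range(100)
}

# prepare -> upload to S3 -> create job -> poll -> download -> load, in one call
results = bi.auto(inputs)
```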

Source code in src/llmbo/batch_inferer.py
class BatchInferer:
    """A class to manage batch inference jobs using AWS Bedrock.

    This class handles the creation, monitoring, and retrieval of batch inference jobs
    for large-scale model invocations using the AWS Bedrock service.

    Args:
        model_name (str): The name/ID of the AWS Bedrock model to use
        bucket_name (str): The S3 bucket name for storing input/output data
        region (str): The region to run the batch inference job in.
        job_name (str): A unique name for the batch inference job
        role_arn (str): The AWS IAM role ARN with necessary permissions
        time_out_duration_hours (int, optional): Maximum job runtime in hours.
            Defaults to 24.
        session (boto3.session, optional): A boto3 session to be used for calls to AWS,
            If one is not provided, a new one will be created

    Attributes:
        job_arn (str): The ARN of the created batch inference job
        results (List[dict]): The results of the batch inference job.
            Available after job completion.
        manifest (Manifest): Job execution statistics. Available after job completion.
        job_status (str): Current status of the batch job.
            One of VALID_FINISHED_STATUSES.
    """

    logger = logging.getLogger(f"{__name__}.BatchInferer")

    def __init__(
        self,
        model_name: str,  # this should be an enum...
        bucket_name: str,
        region: str,
        job_name: str,
        role_arn: str,
        time_out_duration_hours: int = 24,
        session: boto3.Session | None = None,
    ):
        """Initialize a BatchInferer for AWS Bedrock batch processing.

        Creates a configured batch inference manager that handles the end-to-end process
        of submitting and managing batch jobs on AWS Bedrock.

        Args:
            model_name (str): The AWS Bedrock model identifier
                (e.g., 'anthropic.claude-3-haiku-20240307-v1:0')
            bucket_name (str): Name of the S3 bucket for storing job inputs and outputs
            region (str): The region containing the LLM to call; must match the bucket's region.
            job_name (str): Unique identifier for this batch job. Used in file naming.
            role_arn (str): AWS IAM role ARN with permissions for Bedrock and S3 access
            time_out_duration_hours (int, optional): Maximum runtime for the batch job.
                Defaults to 24 hours.
            session (boto3.session, optional): A boto3 session to be used for AWS calls,
                If one is not provided, a new one will be created

        Raises:
            KeyError: If AWS_PROFILE environment variable is not set
            ValueError: If the provided role_arn doesn't exist or is invalid

        Example:
        ```python
            >>> bi = BatchInferer(
                    model_name="anthropic.claude-3-haiku-20240307-v1:0",
                    bucket_name="my-inference-bucket",
                    region="eu-west-2",
                    job_name="batch-job-2024-01-01",
                    role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole"
                )
        ```

        Note:
            - Requires valid AWS credentials and configuration
            - The S3 bucket must exist and be accessible via the provided role
            - Job name will be used to create unique file names for inputs and outputs
        """
        self.logger.info("Intialising BatchInferer")
        # model parameters
        self.model_name = model_name
        self.adapter = self._get_adapter(model_name)

        self.time_out_duration_hours = time_out_duration_hours

        self.session: boto3.Session = session or boto3.Session()

        # file/bucket parameters
        self._check_bucket(bucket_name, region)
        self.bucket_name = bucket_name
        self.bucket_uri = "s3://" + bucket_name
        self.job_name = job_name or "batch_inference_" + str(uuid4())[:6]
        self.file_name = self.job_name + ".jsonl"
        self.output_file_name = None
        self.manifest_file_name = None

        self.check_for_profile()
        self._check_arn(role_arn)
        self.role_arn = role_arn
        self.region = region

        self.client: boto3.client = self.session.client("bedrock", region_name=region)

        # internal state - created by the class later.
        self.job_arn = None
        self.job_status = None
        self.results = None
        self.manifest = None
        self.requests = None

        self.logger.info("Initialized BatchInferer")

    @property
    def unique_id_from_arn(self) -> str:
        """Retrieves the id from the job ARN.

        Raises:
            ValueError: if no job ARN has been set

        Returns:
            str: a unique id portion of the job ARN
        """
        if not self.job_arn:
            self.logger.error("Job ARN not set")
            raise ValueError("Job ARN not set")
        return self.job_arn.split("/")[-1]

    def check_for_profile(self) -> None:
        """Checks if a profile has been set.

        Raises:
            KeyError: If AWS_PROFILE does not exist in the env.
        """
        if not os.getenv("AWS_PROFILE"):
            self.logger.error("AWS_PROFILE environment variable not set")
            raise KeyError("AWS_PROFILE environment variable not set")

    @staticmethod
    def _read_jsonl(file_path):
        data = []
        with open(file_path) as file:
            for line in file:
                data.append(json.loads(line.strip()))
        return data

    def _get_bucket_location(self, bucket_name: str) -> str | None:
        """Get the location of the s3 bucket.

        Args:
            bucket_name (str): the name of a bucket

        Raises:
            ValueError: If the bucket is not accessible

        Returns:
            str: a region, e.g. "eu-west-2"
        """
        try:
            s3_client = self.session.client("s3")
            response = s3_client.get_bucket_location(Bucket=bucket_name)

            if response:
                region = response["LocationConstraint"]
                # aws returns None if the region is us-east-1 otherwise it returns the
                # region
                return region if region else "us-east-1"
        except ClientError as e:
            self.logger.error(f"Bucket {bucket_name} is not accessible: {e}")
            raise ValueError(f"Bucket {bucket_name} is not accessible") from e

    def _check_bucket(self, bucket_name: str, region: str) -> None:
        """Validate if the bucket_name provided exists.

        Args:
            bucket_name (str): the name of a bucket
            region (str): the name of a region

        Raises:
            ValueError: If the bucket is not accessible
            ValueError: If the bucket is not in the same region as the LLM.
        """
        try:
            s3_client = self.session.client("s3")
            s3_client.head_bucket(Bucket=bucket_name)
        except ClientError as e:
            self.logger.error(f"Bucket {bucket_name} is not accessible: {e}")
            raise ValueError(f"Bucket {bucket_name} is not accessible") from e

        if (bucket_region := self._get_bucket_location(bucket_name)) != region:
            self.logger.error(
                f"Bucket {bucket_name} is not located in the same region [{region}] "
                f"as the llm [{bucket_region}]"
            )
            raise ValueError(
                f"Bucket {bucket_name} is not located in the same region [{region}] "
                f"as the llm [{bucket_region}]"
            )

    def _check_arn(self, role_arn: str) -> bool:
        """Validate if an IAM role exists and is accessible.

        Attempts to retrieve the IAM role using the provided ARN to verify its
        existence and accessibility.

        Args:
            role_arn (str): The AWS ARN of the IAM role to check.
                Format: 'arn:aws:iam::<account-id>:role/<role-name>'

        Returns:
            bool: True if the role exists and is accessible.

        Raises:
            ValueError: If the role does not exist.
            ClientError: If there are AWS API issues unrelated to role existence.
        """
        if not role_arn.startswith("arn:aws:iam::"):
            self.logger.error("Invalid ARN format")
            raise ValueError("Invalid ARN format")

        # Extract the role name from the ARN
        role_name = role_arn.split("/")[-1]

        iam_client = self.session.client("iam")

        try:
            # Try to get the role
            iam_client.get_role(RoleName=role_name)
            self.logger.info(f"Role '{role_name}' exists.")
            return True
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchEntity":
                self.logger.error(f"Role '{role_name}' does not exist.")
                raise ValueError(f"Role '{role_name}' does not exist.") from e
            else:
                raise e

    def _get_adapter(self, model_name):
        # Get the appropriate adapter for this model
        try:
            adapter = ModelAdapterRegistry.get_adapter(model_name)
            self.logger.info(f"Using {adapter.__name__} for model {model_name}")
            return adapter
        except ValueError as e:
            self.logger.error(f"No adapter available for model {model_name}")
            raise ValueError(f"No adapter available for model {model_name}") from e

    def prepare_requests(self, inputs: dict[str, ModelInput]) -> None:
        """Prepare batch inference requests from a dictionary of model inputs.

        Formats model inputs into the required JSONL structure for AWS Bedrock
        batch processing. Each request is formatted as:
            {
                "recordId": str,
                "modelInput": dict
            }

        Args:
            inputs (Dict[str, ModelInput]): Dictionary mapping record IDs to their corresponding
                ModelInput configurations. The record IDs will be used to track results.

        Raises:
            ValueError: If len(inputs) < 100, as AWS Bedrock requires minimum batch size of 100

        Example:
            >>> inputs = {
            ...     "001": ModelInput(
            ...         messages=[{"role": "user", "content": "Hello"}],
            ...         temperature=0.7
            ...     ),
            ...     "002": ModelInput(
            ...         messages=[{"role": "user", "content": "Hi"}],
            ...         temperature=0.7
            ...     )
            ... }
            >>> bi.prepare_requests(inputs)

        Note:
            - This method must be called before push_requests_to_s3()
            - The prepared requests are stored in self.requests
            - Each ModelInput is converted to a dict using its to_dict() method
        """
        # TODO: Should I copy these inputs so I don't modify them?
        self.logger.info(f"Preparing {len(inputs)} requests")
        self._check_input_length(inputs)
        self.logger.info("Adding model specific parameters to model_input")
        for id, model_input in inputs.items():
            inputs[id] = self.adapter.prepare_model_input(model_input)

        self.requests = self._to_requests(inputs)

    def _to_requests(self, inputs):
        self.logger.info("Converting to dict")
        return [
            {
                "recordId": id,
                "modelInput": model_input.to_dict(),
            }
            for id, model_input in inputs.items()
        ]

    def _check_input_length(self, inputs):
        if inputs is None:
            self.logger.error("Minimum Batch Size is 100, None supplied")
            raise ValueError("Minimum Batch Size is 100, None supplied")

        if len(inputs) < 100:
            self.logger.error(f"Minimum Batch Size is 100, {len(inputs)} given.")
            raise ValueError(f"Minimum Batch Size is 100, {len(inputs)} given.")

    def _write_requests_locally(self) -> None:
        """Write batch inference requests to a local JSONL file.

        Creates or overwrites a local JSONL file containing the prepared inference
        requests. Each line contains a JSON object with recordId and modelInput.

        Raises:
            IOError: If unable to write to the file
            AttributeError: If called before prepare_requests()

        Note:
            - File is named according to self.file_name
            - Internal method used by push_requests_to_s3()
            - Will overwrite existing files with the same name
        """
        self.logger.info(f"Writing {len(self.requests)} requests to {self.file_name}")
        with open(self.file_name, "w") as file:
            for record in self.requests:
                file.write(json.dumps(record) + "\n")

    def push_requests_to_s3(self) -> dict[str, Any]:
        """Upload batch inference requests to S3.

        Writes the prepared requests to a local JSONL file and uploads it to the
        configured S3 bucket in the 'input/' prefix.

        Returns:
            dict: The S3 upload response from boto3

        Raises:
            IOError: If local file operations fail
            ClientError: If S3 upload fails
            AttributeError: If called before prepare_requests()

        Note:
            - Creates/overwrites files both locally and in S3
            - S3 path: {bucket_name}/input/{job_name}.jsonl
            - Sets Content-Type to 'application/json'
        """
        # do I want to write this file locally? - maybe stream it or write it to
        # temp file instead
        self._write_requests_locally()
        s3_client = self.session.client("s3")
        self.logger.info(f"Pushing {len(self.requests)} requests to {self.bucket_name}")
        response = s3_client.upload_file(
            Filename=self.file_name,
            Bucket=self.bucket_name,
            Key=f"input/{self.file_name}",
            ExtraArgs={"ContentType": "application/json"},
        )
        return response

    def create(self) -> dict[str, Any]:
        """Create a new batch inference job in AWS Bedrock.

        Initializes a new model invocation job using the configured parameters
        and uploaded input data.

        Returns:
            dict: The complete response from the create_model_invocation_job API call

        Raises:
            RuntimeError: If job creation fails
            ClientError: For AWS API errors
            ValueError: If required configurations are missing

        Note:
            - Sets self.job_arn on successful creation
            - Input data must be uploaded to S3 before calling this method
            - Job will timeout after self.time_out_duration_hours
        """
        if self.requests:
            self.logger.info(f"Creating job {self.job_name}")
            response = self.client.create_model_invocation_job(
                jobName=self.job_name,
                roleArn=self.role_arn,
                clientRequestToken="string",
                modelId=self.model_name,
                inputDataConfig={
                    "s3InputDataConfig": {
                        "s3InputFormat": "JSONL",
                        "s3Uri": f"{self.bucket_uri}/input/{self.file_name}",
                    }
                },
                outputDataConfig={
                    "s3OutputDataConfig": {
                        "s3Uri": f"{self.bucket_uri}/output/",
                    }
                },
                timeoutDurationInHours=self.time_out_duration_hours,
                tags=[{"key": "bedrock_batch_inference", "value": self.job_name}],
            )

            if response:
                response_status = response["ResponseMetadata"]["HTTPStatusCode"]
                if response_status == 200:
                    self.logger.info(f"Job {self.job_name} created successfully")
                    self.logger.info(f"Assigned jobArn: {response['jobArn']}")
                    self.job_arn = response["jobArn"]
                    return response
                else:
                    self.logger.error(
                        f"There was an error creating the job {self.job_name},"
                        " non 200 response from bedrock"
                    )
                    raise RuntimeError(
                        f"There was an error creating the job {self.job_name},"
                        " non 200 response from bedrock"
                    )
            else:
                self.logger.error(
                    "There was an error creating the job, no response from bedrock"
                )
                raise RuntimeError(
                    "There was an error creating the job, no response from bedrock"
                )
        else:
            self.logger.error("There were no prepared requests")
            raise AttributeError("There were no prepared requests")

    def download_results(self) -> None:
        """Download batch inference results from S3.

        Retrieves both the results and manifest files from S3 once the job
        has completed. Files are downloaded to:
            - {job_name}_out.jsonl: Contains model outputs
            - {job_name}_manifest.jsonl: Contains job statistics

        Raises:
            ClientError: For S3 download failures
            ValueError: If job hasn't completed or job_arn isn't set

        Note:
            - Only downloads if job status is in VALID_FINISHED_STATUSES
            - Files are downloaded to current working directory
            - Existing files will be overwritten
            - Call check_complete() first to ensure job is finished
        """
        if self.check_complete() in VALID_FINISHED_STATUSES:
            file_name_, ext = os.path.splitext(self.file_name)
            self.output_file_name = f"{file_name_}_out{ext}"
            self.manifest_file_name = f"{file_name_}_manifest{ext}"
            self.logger.info(
                f"Job:{self.job_arn} Complete. Downloading results from {self.bucket_name}"
            )
            s3_client = self.session.client("s3")
            s3_client.download_file(
                Bucket=self.bucket_name,
                Key=f"output/{self.unique_id_from_arn}/{self.file_name}.out",
                Filename=self.output_file_name,
            )
            self.logger.info(f"Downloaded results file to {self.output_file_name}")

            s3_client.download_file(
                Bucket=self.bucket_name,
                Key=f"output/{self.unique_id_from_arn}/manifest.json.out",
                Filename=self.manifest_file_name,
            )
            self.logger.info(f"Downloaded manifest file to {self.manifest_file_name}")
        else:
            self.logger.info(
                f"Job:{self.job_arn} was not marked one of {VALID_FINISHED_STATUSES}, could not download."
            )

    def load_results(self) -> None:
        """Load batch inference results and manifest from local files.

        Reads and parses the output files downloaded from S3, populating:
            - self.results: List of inference results from the output JSONL file
            - self.manifest: Statistics about the job execution (total records, success/error counts, etc.)

        The method expects two files to exist locally:
            - {job_name}_out.jsonl: Contains the model outputs
            - {job_name}_manifest.jsonl: Contains execution statistics

        Raises:
            FileExistsError: If either the results or manifest files are not found locally

        Note:
            - Must call download_results() before calling this method
            - The manifest provides useful metrics like success rate and token counts
        """
        if os.path.isfile(self.output_file_name) and os.path.isfile(
            self.manifest_file_name
        ):
            self.results = self._read_jsonl(self.output_file_name)
            self.manifest = Manifest(**self._read_jsonl(self.manifest_file_name)[0])
        else:
            self.logger.error(
                "Result files do not exist, you may need to call .download_results() first."
            )
            raise FileExistsError(
                "Result files do not exist, you may need to call .download_results() first."
            )

    def cancel_batch(self) -> None:
        """Cancel a running batch inference job.

        Attempts to stop the currently running batch inference job identified by self.job_arn.

        Returns:
            None

        Raises:
            RuntimeError: If the job cancellation request fails
            ValueError: If no job_arn is set (i.e., no job has been created)
        """
        if not self.job_arn:
            self.logger.error("No job_arn set - no job to cancel")
            raise ValueError("No job_arn set - no job to cancel")

        response = self.client.stop_model_invocation_job(jobIdentifier=self.job_arn)

        if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
            self.logger.info(
                f"Job {self.job_name} with id={self.job_arn} was cancelled"
            )
            self.job_status = "Stopped"
        else:
            self.logger.error(
                f"Failed to cancel job {self.job_name}. Status: {response['ResponseMetadata']['HTTPStatusCode']}"
            )
            raise RuntimeError(f"Failed to cancel job {self.job_name}")

    def check_complete(self) -> str | None:
        """Check if the batch inference job has completed.

        Returns:
            str | None: The job status if the job has finished (one of 'Completed',
                'Failed', 'Stopped', or 'Expired'), or None if the job is still in
                progress.
        """
        if self.job_status not in VALID_FINISHED_STATUSES:
            self.logger.info(f"Checking status of job {self.job_arn}")
            response = self.client.get_model_invocation_job(jobIdentifier=self.job_arn)

            self.job_status = response["status"]
            self.logger.info(f"Job status is {self.job_status}")

            if self.job_status in VALID_FINISHED_STATUSES:
                return self.job_status
            return None
        else:
            self.logger.info(f"Job {self.job_arn} is already {self.job_status}")
            return self.job_status

    def poll_progress(self, poll_interval_seconds: int = 60) -> bool:
        """Polls the progress of a job.

        Args:
            poll_interval_seconds (int, optional): Number of seconds between checks. Defaults to 60.

        Returns:
            bool: True if job is complete.
        """
        self.logger.info(f"Polling for progress every {poll_interval_seconds} seconds")
        while not self.check_complete():
            time.sleep(poll_interval_seconds)
        return True

    def auto(self, inputs: dict[str, ModelInput], poll_time_secs: int = 60) -> list[dict]:
        """Execute the complete batch inference workflow automatically.

        This method combines the preparation, execution, monitoring, and result retrieval
        steps into a single operation.

        Args:
            inputs (Dict[str, ModelInput]): Dictionary of record IDs mapped to their ModelInput configurations
            poll_time_secs (int, optional): How often to poll for model progress. Defaults to 60.

        Returns:
            List[Dict]: The results from the batch inference job
        """
        self.prepare_requests(inputs)
        self.push_requests_to_s3()
        self.create()
        self.poll_progress(poll_time_secs)
        self.download_results()
        self.load_results()
        return self.results

    @classmethod
    def recover_details_from_job_arn(
        cls, job_arn: str, region: str, session: boto3.Session | None = None
    ) -> "BatchInferer":
        """Recover a BatchInferer instance from an existing job ARN.

        Used to reconstruct a BatchInferer object when the original Python process
        has terminated but the AWS job is still running or complete.

        Args:
            job_arn: (str) The AWS ARN of the existing batch inference job
            region: (str) the region where the job was scheduled
            session (boto3.session, optional): A boto3 session to be used for calls to AWS,
                If one is not provided, a new one will be created

        Returns:
            BatchInferer: A configured instance with the job's details

        Raises:
            ValueError: If the job cannot be found or response is invalid

        Example:
            >>> job_arn = "arn:aws:bedrock:region:account:job/xyz123"
            >>> bi = BatchInferer.recover_details_from_job_arn(job_arn, region="eu-west-2")
            >>> bi.check_complete()
            'Completed'
        """
        cls.logger.info(f"Attempting to Recover BatchInferer from {job_arn}")
        session = session or boto3.Session()
        response = cls.check_for_existing_job(job_arn, region, session)

        try:
            # Extract required parameters from response
            job_name = response["jobName"]
            model_id = response["modelId"]
            bucket_name = response["inputDataConfig"]["s3InputDataConfig"][
                "s3Uri"
            ].split("/")[2]
            role_arn = response["roleArn"]

            # Validate required files exist
            input_file = f"{job_name}.jsonl"
            if not os.path.exists(input_file):
                cls.logger.error(f"Required input file not found: {input_file}")
                raise FileNotFoundError(f"Required input file not found: {input_file}")

            requests = cls._read_jsonl(input_file)

            bi = cls(
                model_name=model_id,
                job_name=job_name,
                region=region,
                bucket_name=bucket_name,
                role_arn=role_arn,
                session=session,
            )
            bi.job_arn = job_arn
            bi.requests = requests
            bi.job_status = response["status"]

            return bi

        except (KeyError, IndexError) as e:
            cls.logger.error(f"Invalid job response format: {str(e)}")
            raise ValueError(f"Invalid job response format: {str(e)}") from e
        except Exception as e:
            cls.logger.error(f"Failed to recover job details: {str(e)}")
            raise RuntimeError(f"Failed to recover job details: {str(e)}") from e

    @classmethod
    def check_for_existing_job(
        cls, job_arn, region, session: boto3.Session | None = None
    ) -> dict[str, Any]:
        """Check if a job exists and return its details.

        Args:
            job_arn (str): The AWS ARN of the job to check
            region (str): The AWS region where the job was created
            session (boto3.Session, optional): A boto3 session to be used for AWS API calls.
                                           If not provided, a new session will be created.

        Returns:
            Dict[str, Any]: The job details from AWS Bedrock

        Raises:
            ValueError: If the job ARN is invalid or the job is not found
            RuntimeError: For other AWS API errors
        """
        if not job_arn.startswith("arn:aws:bedrock:"):
            cls.logger.error(f"Invalid Bedrock ARN format: {job_arn}")
            raise ValueError(f"Invalid Bedrock ARN format: {job_arn}")
        session = session or boto3.Session()
        client = session.client("bedrock", region_name=region)

        try:
            response = client.get_model_invocation_job(jobIdentifier=job_arn)
        except ClientError as e:
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                cls.logger.error(f"Job not found: {job_arn}")
                raise ValueError(f"Job not found: {job_arn}") from e
            cls.logger.error(f"AWS API error: {str(e)}")
            raise RuntimeError(f"AWS API error: {str(e)}") from e

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            cls.logger.error(
                f"Unexpected response status: {response['ResponseMetadata']['HTTPStatusCode']}"
            )
            raise RuntimeError(
                f"Unexpected response status: {response['ResponseMetadata']['HTTPStatusCode']}"
            )

        return response

unique_id_from_arn property

Retrieves the id from the job ARN.

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If no job ARN has been set |

Returns:

| Type | Description |
| --- | --- |
| `str` | The unique ID portion of the job ARN |
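
For example, the property simply returns the final path segment of the job ARN (the ARN below is illustrative):

```python
bi.job_arn = "arn:aws:bedrock:eu-west-2:123456789012:model-invocation-job/abc123xyz"
assert bi.unique_id_from_arn == "abc123xyz"  # everything after the last "/"
```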

__init__(model_name, bucket_name, region, job_name, role_arn, time_out_duration_hours=24, session=None)

Initialize a BatchInferer for AWS Bedrock batch processing.

Creates a configured batch inference manager that handles the end-to-end process of submitting and managing batch jobs on AWS Bedrock.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_name` | `str` | The AWS Bedrock model identifier (e.g., `'anthropic.claude-3-haiku-20240307-v1:0'`) | *required* |
| `bucket_name` | `str` | Name of the S3 bucket for storing job inputs and outputs | *required* |
| `region` | `str` | The region containing the LLM to call; must match the bucket's region | *required* |
| `job_name` | `str` | Unique identifier for this batch job. Used in file naming. | *required* |
| `role_arn` | `str` | AWS IAM role ARN with permissions for Bedrock and S3 access | *required* |
| `time_out_duration_hours` | `int` | Maximum runtime for the batch job. Defaults to 24 hours. | `24` |
| `session` | `boto3.Session` | A boto3 session to use for AWS calls. If one is not provided, a new one will be created. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If the `AWS_PROFILE` environment variable is not set |
| `ValueError` | If the provided `role_arn` doesn't exist or is invalid |

Example:

    >>> bi = BatchInferer(
            model_name="anthropic.claude-3-haiku-20240307-v1:0",
            bucket_name="my-inference-bucket",
            region="eu-west-2",
            job_name="batch-job-2024-01-01",
            role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole"
        )

Note
  • Requires valid AWS credentials and configuration
  • The S3 bucket must exist and be accessible via the provided role
  • Job name will be used to create unique file names for inputs and outputs
Source code in src/llmbo/batch_inferer.py
def __init__(
    self,
    model_name: str,  # this should be an enum...
    bucket_name: str,
    region: str,
    job_name: str,
    role_arn: str,
    time_out_duration_hours: int = 24,
    session: boto3.Session | None = None,
):
    """Initialize a BatchInferer for AWS Bedrock batch processing.

    Creates a configured batch inference manager that handles the end-to-end process
    of submitting and managing batch jobs on AWS Bedrock.

    Args:
        model_name (str): The AWS Bedrock model identifier
            (e.g., 'anthropic.claude-3-haiku-20240307-v1:0')
        bucket_name (str): Name of the S3 bucket for storing job inputs and outputs
        region (str): The region containing the LLM to call; must match the bucket's region.
        job_name (str): Unique identifier for this batch job. Used in file naming.
        role_arn (str): AWS IAM role ARN with permissions for Bedrock and S3 access
        time_out_duration_hours (int, optional): Maximum runtime for the batch job.
            Defaults to 24 hours.
        session (boto3.session, optional): A boto3 session to be used for AWS calls,
            If one is not provided, a new one will be created

    Raises:
        KeyError: If AWS_PROFILE environment variable is not set
        ValueError: If the provided role_arn doesn't exist or is invalid

    Example:
    ```python
        >>> bi = BatchInferer(
                model_name="anthropic.claude-3-haiku-20240307-v1:0",
                bucket_name="my-inference-bucket",
                region="eu-west-2",
                job_name="batch-job-2024-01-01",
                role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole"
            )
    ```

    Note:
        - Requires valid AWS credentials and configuration
        - The S3 bucket must exist and be accessible via the provided role
        - Job name will be used to create unique file names for inputs and outputs
    """
    self.logger.info("Intialising BatchInferer")
    # model parameters
    self.model_name = model_name
    self.adapter = self._get_adapter(model_name)

    self.time_out_duration_hours = time_out_duration_hours

    self.session: boto3.Session = session or boto3.Session()

    # file/bucket parameters
    self._check_bucket(bucket_name, region)
    self.bucket_name = bucket_name
    self.bucket_uri = "s3://" + bucket_name
    self.job_name = job_name or "batch_inference_" + str(uuid4())[:6]
    self.file_name = self.job_name + ".jsonl"
    self.output_file_name = None
    self.manifest_file_name = None

    self.check_for_profile()
    self._check_arn(role_arn)
    self.role_arn = role_arn
    self.region = region

    self.client: boto3.client = self.session.client("bedrock", region_name=region)

    # internal state - created by the class later.
    self.job_arn = None
    self.job_status = None
    self.results = None
    self.manifest = None
    self.requests = None

    self.logger.info("Initialized BatchInferer")

_check_arn(role_arn)

Validate if an IAM role exists and is accessible.

Attempts to retrieve the IAM role using the provided ARN to verify its existence and accessibility.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `role_arn` | `str` | The AWS ARN of the IAM role to check. Format: `arn:aws:iam::<account-id>:role/<role-name>` | *required* |

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if the role exists and is accessible. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the role does not exist. |
| `ClientError` | If there are AWS API issues unrelated to role existence. |

Source code in src/llmbo/batch_inferer.py
def _check_arn(self, role_arn: str) -> bool:
    """Validate if an IAM role exists and is accessible.

    Attempts to retrieve the IAM role using the provided ARN to verify its
    existence and accessibility.

    Args:
        role_arn (str): The AWS ARN of the IAM role to check.
            Format: 'arn:aws:iam::<account-id>:role/<role-name>'

    Returns:
        bool: True if the role exists and is accessible.

    Raises:
        ValueError: If the role does not exist.
        ClientError: If there are AWS API issues unrelated to role existence.
    """
    if not role_arn.startswith("arn:aws:iam::"):
        self.logger.error("Invalid ARN format")
        raise ValueError("Invalid ARN format")

    # Extract the role name from the ARN
    role_name = role_arn.split("/")[-1]

    iam_client = self.session.client("iam")

    try:
        # Try to get the role
        iam_client.get_role(RoleName=role_name)
        self.logger.info(f"Role '{role_name}' exists.")
        return True
    except ClientError as e:
        if e.response["Error"]["Code"] == "NoSuchEntity":
            self.logger.error(f"Role '{role_name}' does not exist.")
            raise ValueError(f"Role '{role_name}' does not exist.") from e
        else:
            raise e

_check_bucket(bucket_name, region)

Validate if the bucket_name provided exists.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `bucket_name` | `str` | The name of a bucket | *required* |
| `region` | `str` | The name of a region | *required* |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the bucket is not accessible |
| `ValueError` | If the bucket is not in the same region as the LLM |

Source code in src/llmbo/batch_inferer.py
def _check_bucket(self, bucket_name: str, region: str) -> None:
    """Validate if the bucket_name provided exists.

    Args:
        bucket_name (str): the name of a bucket
        region (str): the name of a region

    Raises:
        ValueError: If the bucket is not accessible
        ValueError: If the bucket is not in the same region as the LLM.
    """
    try:
        s3_client = self.session.client("s3")
        s3_client.head_bucket(Bucket=bucket_name)
    except ClientError as e:
        self.logger.error(f"Bucket {bucket_name} is not accessible: {e}")
        raise ValueError(f"Bucket {bucket_name} is not accessible") from e

    if (bucket_region := self._get_bucket_location(bucket_name)) != region:
        self.logger.error(
            f"Bucket {bucket_name} is not located in the same region [{region}] "
            f"as the llm [{bucket_region}]"
        )
        raise ValueError(
            f"Bucket {bucket_name} is not located in the same region [{region}] "
            f"as the llm [{bucket_region}]"
        )

_get_bucket_location(bucket_name)

Get the location of the s3 bucket.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `bucket_name` | `str` | The name of a bucket | *required* |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the bucket is not accessible |

Returns:

| Type | Description |
| --- | --- |
| `str \| None` | A region, e.g. `"eu-west-2"` |

Source code in src/llmbo/batch_inferer.py
def _get_bucket_location(self, bucket_name: str) -> str | None:
    """Get the location of the s3 bucket.

    Args:
        bucket_name (str): the name of a bucket

    Raises:
        ValueError: If the bucket is not accessible

    Returns:
        str: a region, e.g. "eu-west-2"
    """
    try:
        s3_client = self.session.client("s3")
        response = s3_client.get_bucket_location(Bucket=bucket_name)

        if response:
            region = response["LocationConstraint"]
            # aws returns None if the region is us-east-1 otherwise it returns the
            # region
            return region if region else "us-east-1"
    except ClientError as e:
        self.logger.error(f"Bucket {bucket_name} is not accessible: {e}")
        raise ValueError(f"Bucket {bucket_name} is not accessible") from e

_write_requests_locally()

Write batch inference requests to a local JSONL file.

Creates or overwrites a local JSONL file containing the prepared inference requests. Each line contains a JSON object with recordId and modelInput.

Raises:

| Type | Description |
| --- | --- |
| `IOError` | If unable to write to the file |
| `AttributeError` | If called before `prepare_requests()` |

Note
  • File is named according to self.file_name
  • Internal method used by push_requests_to_s3()
  • Will overwrite existing files with the same name
Source code in src/llmbo/batch_inferer.py
def _write_requests_locally(self) -> None:
    """Write batch inference requests to a local JSONL file.

    Creates or overwrites a local JSONL file containing the prepared inference
    requests. Each line contains a JSON object with recordId and modelInput.

    Raises:
        IOError: If unable to write to the file
        AttributeError: If called before prepare_requests()

    Note:
        - File is named according to self.file_name
        - Internal method used by push_requests_to_s3()
        - Will overwrite existing files with the same name
    """
    self.logger.info(f"Writing {len(self.requests)} requests to {self.file_name}")
    with open(self.file_name, "w") as file:
        for record in self.requests:
            file.write(json.dumps(record) + "\n")
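
For reference, a sketch of what a single line of the generated file contains; the values are hypothetical, and the exact `modelInput` fields depend on the model adapter in use:

```python
import json

# One request record as written to {job_name}.jsonl (illustrative values;
# messages/temperature shown for an Anthropic-style model).
record = {
    "recordId": "001",
    "modelInput": {
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.7,
    },
}
print(json.dumps(record))  # each record occupies one line of the JSONL file
```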

auto(inputs, poll_time_secs=60)

Execute the complete batch inference workflow automatically.

This method combines the preparation, execution, monitoring, and result retrieval steps into a single operation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `Dict[str, ModelInput]` | Dictionary of record IDs mapped to their ModelInput configurations | *required* |
| `poll_time_secs` | `int` | How often to poll for job progress. Defaults to 60. | `60` |

Returns:

| Type | Description |
| --- | --- |
| `list[dict]` | The results from the batch inference job |

Source code in src/llmbo/batch_inferer.py
def auto(self, inputs: dict[str, ModelInput], poll_time_secs: int = 60) -> list[dict]:
    """Execute the complete batch inference workflow automatically.

    This method combines the preparation, execution, monitoring, and result retrieval
    steps into a single operation.

    Args:
        inputs (Dict[str, ModelInput]): Dictionary of record IDs mapped to their ModelInput configurations
        poll_time_secs (int, optional): How often to poll for model progress. Defaults to 60.

    Returns:
        List[Dict]: The results from the batch inference job
    """
    self.prepare_requests(inputs)
    self.push_requests_to_s3()
    self.create()
    self.poll_progress(poll_time_secs)
    self.download_results()
    self.load_results()
    return self.results

cancel_batch()

Cancel a running batch inference job.

Attempts to stop the currently running batch inference job identified by self.job_arn.

Returns:

| Type | Description |
| --- | --- |
| `None` | None |

Raises:

| Type | Description |
| --- | --- |
| `RuntimeError` | If the job cancellation request fails |
| `ValueError` | If no `job_arn` is set (i.e., no job has been created) |

Source code in src/llmbo/batch_inferer.py
def cancel_batch(self) -> None:
    """Cancel a running batch inference job.

    Attempts to stop the currently running batch inference job identified by self.job_arn.

    Returns:
        None

    Raises:
        RuntimeError: If the job cancellation request fails
        ValueError: If no job_arn is set (i.e., no job has been created)
    """
    if not self.job_arn:
        self.logger.error("No job_arn set - no job to cancel")
        raise ValueError("No job_arn set - no job to cancel")

    response = self.client.stop_model_invocation_job(jobIdentifier=self.job_arn)

    if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
        self.logger.info(
            f"Job {self.job_name} with id={self.job_arn} was cancelled"
        )
        self.job_status = "Stopped"
    else:
        self.logger.error(
            f"Failed to cancel job {self.job_name}. Status: {response['ResponseMetadata']['HTTPStatusCode']}"
        )
        raise RuntimeError(f"Failed to cancel job {self.job_name}")

check_complete()

Check if the batch inference job has completed.

Returns:

| Type | Description |
| --- | --- |
| `str \| None` | The job status if the job has finished (one of 'Completed', 'Failed', 'Stopped', or 'Expired'), or None if the job is still in progress |

Source code in src/llmbo/batch_inferer.py
def check_complete(self) -> str | None:
    """Check if the batch inference job has completed.

    Returns:
        str | None: The job status if the job has finished (one of 'Completed',
            'Failed', 'Stopped', or 'Expired'), or None if the job is still in
            progress.
    """
    if self.job_status not in VALID_FINISHED_STATUSES:
        self.logger.info(f"Checking status of job {self.job_arn}")
        response = self.client.get_model_invocation_job(jobIdentifier=self.job_arn)

        self.job_status = response["status"]
        self.logger.info(f"Job status is {self.job_status}")

        if self.job_status in VALID_FINISHED_STATUSES:
            return self.job_status
        return None
    else:
        self.logger.info(f"Job {self.job_arn} is already {self.job_status}")
        return self.job_status

check_for_existing_job(job_arn, region, session=None) classmethod

Check if a job exists and return its details.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `job_arn` | `str` | The AWS ARN of the job to check | *required* |
| `region` | `str` | The AWS region where the job was created | *required* |
| `session` | `boto3.Session` | A boto3 session to be used for AWS API calls. If not provided, a new session will be created. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Any]` | The job details from AWS Bedrock |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the job ARN is invalid or the job is not found |
| `RuntimeError` | For other AWS API errors |

Source code in src/llmbo/batch_inferer.py
@classmethod
def check_for_existing_job(
    cls, job_arn, region, session: boto3.Session | None = None
) -> dict[str, Any]:
    """Check if a job exists and return its details.

    Args:
        job_arn (str): The AWS ARN of the job to check
        region (str): The AWS region where the job was created
        session (boto3.Session, optional): A boto3 session to be used for AWS API calls.
                                       If not provided, a new session will be created.

    Returns:
        Dict[str, Any]: The job details from AWS Bedrock

    Raises:
        ValueError: If the job ARN is invalid or the job is not found
        RuntimeError: For other AWS API errors
    """
    if not job_arn.startswith("arn:aws:bedrock:"):
        cls.logger.error(f"Invalid Bedrock ARN format: {job_arn}")
        raise ValueError(f"Invalid Bedrock ARN format: {job_arn}")
    session = session or boto3.Session()
    client = session.client("bedrock", region_name=region)

    try:
        response = client.get_model_invocation_job(jobIdentifier=job_arn)
    except ClientError as e:
        if e.response["Error"]["Code"] == "ResourceNotFoundException":
            cls.logger.error(f"Job not found: {job_arn}")
            raise ValueError(f"Job not found: {job_arn}") from e
        cls.logger.error(f"AWS API error: {str(e)}")
        raise RuntimeError(f"AWS API error: {str(e)}") from e

    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        cls.logger.error(
            f"Unexpected response status: {response['ResponseMetadata']['HTTPStatusCode']}"
        )
        raise RuntimeError(
            f"Unexpected response status: {response['ResponseMetadata']['HTTPStatusCode']}"
        )

    return response
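
A short sketch of inspecting a previously submitted job; the ARN and region are placeholders:

```python
job_arn = "arn:aws:bedrock:eu-west-2:123456789012:model-invocation-job/abc123xyz"

details = BatchInferer.check_for_existing_job(job_arn, region="eu-west-2")
print(details["status"])  # e.g. "InProgress"

# To resume working with the job from a fresh process, rebuild the full object.
# This requires the original {job_name}.jsonl input file in the working directory.
bi = BatchInferer.recover_details_from_job_arn(job_arn, region="eu-west-2")
```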

check_for_profile()

Checks if a profile has been set.

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If `AWS_PROFILE` does not exist in the env |

Source code in src/llmbo/batch_inferer.py
def check_for_profile(self) -> None:
    """Checks if a profile has been set.

    Raises:
        KeyError: If AWS_PROFILE does not exist in the env.
    """
    if not os.getenv("AWS_PROFILE"):
        self.logger.error("AWS_PROFILE environment variable not set")
        raise KeyError("AWS_PROFILE environment variable not set")
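
Since the constructor runs this check, `AWS_PROFILE` must be set before a `BatchInferer` is created; for example (the profile name is a placeholder):

```python
import os

# Equivalent to `export AWS_PROFILE=my-aws-profile` in the shell.
os.environ["AWS_PROFILE"] = "my-aws-profile"
```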

create()

Create a new batch inference job in AWS Bedrock.

Initializes a new model invocation job using the configured parameters and uploaded input data.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | The complete response from the `create_model_invocation_job` API call |

Raises:

| Type | Description |
| --- | --- |
| `RuntimeError` | If job creation fails |
| `ClientError` | For AWS API errors |
| `ValueError` | If required configurations are missing |

Note
  • Sets self.job_arn on successful creation
  • Input data must be uploaded to S3 before calling this method
  • Job will timeout after self.time_out_duration_hours
Source code in src/llmbo/batch_inferer.py
def create(self) -> dict[str, Any]:
    """Create a new batch inference job in AWS Bedrock.

    Initializes a new model invocation job using the configured parameters
    and uploaded input data.

    Returns:
        dict: The complete response from the create_model_invocation_job API call

    Raises:
        RuntimeError: If job creation fails
        ClientError: For AWS API errors
        ValueError: If required configurations are missing

    Note:
        - Sets self.job_arn on successful creation
        - Input data must be uploaded to S3 before calling this method
        - Job will timeout after self.time_out_duration_hours
    """
    if self.requests:
        self.logger.info(f"Creating job {self.job_name}")
        response = self.client.create_model_invocation_job(
            jobName=self.job_name,
            roleArn=self.role_arn,
            clientRequestToken="string",
            modelId=self.model_name,
            inputDataConfig={
                "s3InputDataConfig": {
                    "s3InputFormat": "JSONL",
                    "s3Uri": f"{self.bucket_uri}/input/{self.file_name}",
                }
            },
            outputDataConfig={
                "s3OutputDataConfig": {
                    "s3Uri": f"{self.bucket_uri}/output/",
                }
            },
            timeoutDurationInHours=self.time_out_duration_hours,
            tags=[{"key": "bedrock_batch_inference", "value": self.job_name}],
        )

        if response:
            response_status = response["ResponseMetadata"]["HTTPStatusCode"]
            if response_status == 200:
                self.logger.info(f"Job {self.job_name} created successfully")
                self.logger.info(f"Assigned jobArn: {response['jobArn']}")
                self.job_arn = response["jobArn"]
                return response
            else:
                self.logger.error(
                    f"There was an error creating the job {self.job_name},"
                    " non 200 response from bedrock"
                )
                raise RuntimeError(
                    f"There was an error creating the job {self.job_name},"
                    " non 200 response from bedrock"
                )
        else:
            self.logger.error(
                "There was an error creating the job, no response from bedrock"
            )
            raise RuntimeError(
                "There was an error creating the job, no response from bedrock"
            )
    else:
        self.logger.error("There were no prepared requests")
        raise AttributeError("There were no prepared requests")
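
For finer control than `auto()`, the same steps can be run individually; a sketch, assuming `bi` and `inputs` are set up as in the class-level example:

```python
bi.prepare_requests(inputs)   # format and validate (minimum 100 records)
bi.push_requests_to_s3()      # write {job_name}.jsonl locally and upload to input/
response = bi.create()        # submit the Bedrock job; sets bi.job_arn on success
print(bi.job_arn)
```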

download_results()

Download batch inference results from S3.

Retrieves both the results and manifest files from S3 once the job has completed. Files are downloaded to:
  • {job_name}_out.jsonl: Contains model outputs
  • {job_name}_manifest.jsonl: Contains job statistics

Raises:

| Type | Description |
| --- | --- |
| `ClientError` | For S3 download failures |
| `ValueError` | If job hasn't completed or `job_arn` isn't set |

Note
  • Only downloads if job status is in VALID_FINISHED_STATUSES
  • Files are downloaded to current working directory
  • Existing files will be overwritten
  • Call check_complete() first to ensure job is finished
Source code in src/llmbo/batch_inferer.py
def download_results(self) -> None:
    """Download batch inference results from S3.

    Retrieves both the results and manifest files from S3 once the job
    has completed. Files are downloaded to:
        - {job_name}_out.jsonl: Contains model outputs
        - {job_name}_manifest.jsonl: Contains job statistics

    Raises:
        ClientError: For S3 download failures
        ValueError: If job hasn't completed or job_arn isn't set

    Note:
        - Only downloads if job status is in VALID_FINISHED_STATUSES
        - Files are downloaded to current working directory
        - Existing files will be overwritten
        - Call check_complete() first to ensure job is finished
    """
    if self.check_complete() in VALID_FINISHED_STATUSES:
        file_name_, ext = os.path.splitext(self.file_name)
        self.output_file_name = f"{file_name_}_out{ext}"
        self.manifest_file_name = f"{file_name_}_manifest{ext}"
        self.logger.info(
            f"Job:{self.job_arn} Complete. Downloading results from {self.bucket_name}"
        )
        s3_client = self.session.client("s3")
        s3_client.download_file(
            Bucket=self.bucket_name,
            Key=f"output/{self.unique_id_from_arn}/{self.file_name}.out",
            Filename=self.output_file_name,
        )
        self.logger.info(f"Downloaded results file to {self.output_file_name}")

        s3_client.download_file(
            Bucket=self.bucket_name,
            Key=f"output/{self.unique_id_from_arn}/manifest.json.out",
            Filename=self.manifest_file_name,
        )
        self.logger.info(f"Downloaded manifest file to {self.manifest_file_name}")
    else:
        self.logger.info(
            f"Job:{self.job_arn} was not marked one of {VALID_FINISHED_STATUSES}, could not download."
        )

load_results()

Load batch inference results and manifest from local files.

Reads and parses the output files downloaded from S3, populating:
  • self.results: List of inference results from the output JSONL file
  • self.manifest: Statistics about the job execution (total records, success/error counts, etc.)

The method expects two files to exist locally
  • {job_name}_out.jsonl: Contains the model outputs
  • {job_name}_manifest.jsonl: Contains execution statistics

Raises:

| Type | Description |
| --- | --- |
| `FileExistsError` | If either the results or manifest files are not found locally |

Note
  • Must call download_results() before calling this method
  • The manifest provides useful metrics like success rate and token counts
Source code in src/llmbo/batch_inferer.py
def load_results(self) -> None:
    """Load batch inference results and manifest from local files.

    Reads and parses the output files downloaded from S3, populating:
        - self.results: List of inference results from the output JSONL file
        - self.manifest: Statistics about the job execution (total records, success/error counts, etc.)

    The method expects two files to exist locally:
        - {job_name}_out.jsonl: Contains the model outputs
        - {job_name}_manifest.jsonl: Contains execution statistics

    Raises:
        FileExistsError: If either the results or manifest files are not found locally

    Note:
        - Must call download_results() before calling this method
        - The manifest provides useful metrics like success rate and token counts
    """
    if os.path.isfile(self.output_file_name) and os.path.isfile(
        self.manifest_file_name
    ):
        self.results = self._read_jsonl(self.output_file_name)
        self.manifest = Manifest(**self._read_jsonl(self.manifest_file_name)[0])
    else:
        self.logger.error(
            "Result files do not exist; you may need to call .download_results() first."
        )
        raise FileExistsError(
            "Result files do not exist; you may need to call .download_results() first."
        )
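A sketch of inspecting loaded results, assuming download_results() has already run on an instance bi:

bi.load_results()
print(f"{bi.manifest.successRecordCount}/{bi.manifest.totalRecordCount} records succeeded")
for result in bi.results:
    # Each result dict carries a "recordId" and, when inference succeeded, a "modelOutput".
    if result.get("modelOutput"):
        print(result["recordId"])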

poll_progress(poll_interval_seconds=60)

Polls the progress of a job.

Parameters:

Name Type Description Default
poll_interval_seconds int

Number of seconds between checks. Defaults to 60.

60

Returns:

Name Type Description
bool bool

True if job is complete.

Source code in src/llmbo/batch_inferer.py
def poll_progress(self, poll_interval_seconds: int = 60) -> bool:
    """Polls the progress of a job.

    Args:
        poll_interval_seconds (int, optional): Number of seconds between checks. Defaults to 60.

    Returns:
        bool: True if job is complete.
    """
    self.logger.info(f"Polling for progress every {poll_interval_seconds} seconds")
    while not self.check_complete():
        time.sleep(poll_interval_seconds)
    return True
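If blocking indefinitely is undesirable, the same loop can be written by hand around check_complete(), for example to enforce a deadline; a sketch assuming bi holds a submitted job:

import time

deadline = time.time() + 6 * 3600  # give up after six hours
while not bi.check_complete():
    if time.time() > deadline:
        raise TimeoutError("Batch job did not finish in time")
    time.sleep(60)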

prepare_requests(inputs)

Prepare batch inference requests from a dictionary of model inputs.

Formats model inputs into the required JSONL structure for AWS Bedrock batch processing. Each request is formatted as: { "recordId": str, "modelInput": dict }

Parameters:

Name Type Description Default
inputs Dict[str, ModelInput]

Dictionary mapping record IDs to their corresponding ModelInput configurations. The record IDs will be used to track results.

required

Raises:

Type Description
ValueError

If len(inputs) < 100, as AWS Bedrock requires a minimum batch size of 100

Example

>>> inputs = {
...     "001": ModelInput(
...         messages=[{"role": "user", "content": "Hello"}],
...         temperature=0.7
...     ),
...     "002": ModelInput(
...         messages=[{"role": "user", "content": "Hi"}],
...         temperature=0.7
...     )
... }
>>> bi.prepare_requests(inputs)

Note
  • This method must be called before push_requests_to_s3()
  • The prepared requests are stored in self.requests
  • Each ModelInput is converted to a dict using its to_dict() method
Source code in src/llmbo/batch_inferer.py
def prepare_requests(self, inputs: dict[str, ModelInput]) -> None:
    """Prepare batch inference requests from a dictionary of model inputs.

    Formats model inputs into the required JSONL structure for AWS Bedrock
    batch processing. Each request is formatted as:
        {
            "recordId": str,
            "modelInput": dict
        }

    Args:
        inputs (Dict[str, ModelInput]): Dictionary mapping record IDs to their corresponding
            ModelInput configurations. The record IDs will be used to track results.

    Raises:
        ValueError: If len(inputs) < 100, as AWS Bedrock requires a minimum batch size of 100

    Example:
        >>> inputs = {
        ...     "001": ModelInput(
        ...         messages=[{"role": "user", "content": "Hello"}],
        ...         temperature=0.7
        ...     ),
        ...     "002": ModelInput(
        ...         messages=[{"role": "user", "content": "Hi"}],
        ...         temperature=0.7
        ...     )
        ... }
        >>> bi.prepare_requests(inputs)

    Note:
        - This method must be called before push_requests_to_s3()
        - The prepared requests are stored in self.requests
        - Each ModelInput is converted to a dict using its to_dict() method
    """
    # TODO: Should I copy these inputs so I don't modify them?
    self.logger.info(f"Preparing {len(inputs)} requests")
    self._check_input_length(inputs)
    self.logger.info("Adding model specific parameters to model_input")
    for id, model_input in inputs.items():
        inputs[id] = self.adapter.prepare_model_input(model_input)

    self.requests = self._to_requests(inputs)
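A sketch of assembling the inputs dictionary at scale; the prompts are placeholder data, sized to satisfy the 100-record minimum:

prompts = [f"Summarise document {i}" for i in range(100)]  # placeholder data
inputs = {
    f"{i:03d}": ModelInput(messages=[{"role": "user", "content": p}])
    for i, p in enumerate(prompts)
}
bi.prepare_requests(inputs)  # formatted requests are stored on bi.requests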

push_requests_to_s3()

Upload batch inference requests to S3.

Writes the prepared requests to a local JSONL file and uploads it to the configured S3 bucket in the 'input/' prefix.

Returns:

Name Type Description
dict dict[str, Any]

The S3 upload response from boto3

Raises:

Type Description
IOError

If local file operations fail

ClientError

If S3 upload fails

AttributeError

If called before prepare_requests()

Note
  • Creates/overwrites files both locally and in S3
  • S3 path: {bucket_name}/input/{job_name}.jsonl
  • Sets Content-Type to 'application/json'
Source code in src/llmbo/batch_inferer.py
def push_requests_to_s3(self) -> dict[str, Any]:
    """Upload batch inference requests to S3.

    Writes the prepared requests to a local JSONL file and uploads it to the
    configured S3 bucket in the 'input/' prefix.

    Returns:
        dict: The S3 upload response from boto3

    Raises:
        IOError: If local file operations fail
        ClientError: If S3 upload fails
        AttributeError: If called before prepare_requests()

    Note:
        - Creates/overwrites files both locally and in S3
        - S3 path: {bucket_name}/input/{job_name}.jsonl
        - Sets Content-Type to 'application/json'
    """
    # Do I want to write this file locally? Maybe stream it or write it to a
    # temp file instead.
    self._write_requests_locally()
    s3_client = self.session.client("s3")
    self.logger.info(f"Pushing {len(self.requests)} requests to {self.bucket_name}")
    response = s3_client.upload_file(
        Filename=self.file_name,
        Bucket=self.bucket_name,
        Key=f"input/{self.file_name}",
        ExtraArgs={"ContentType": "application/json"},
    )
    return response
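One caveat worth knowing: boto3's S3.Client.upload_file returns None, so in practice the value returned here is None despite the dict annotation. A sketch of the prepare-then-push ordering, with bi and inputs as above:

bi.prepare_requests(inputs)  # must run first; otherwise self.requests is unset
bi.push_requests_to_s3()     # uploads to s3://{bucket_name}/input/{job_name}.jsonl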

recover_details_from_job_arn(job_arn, region, session=None) classmethod

Recover a BatchInferer instance from an existing job ARN.

Used to reconstruct a BatchInferer object when the original Python process has terminated but the AWS job is still running or complete.

Parameters:

Name Type Description Default
job_arn str

The AWS ARN of the existing batch inference job

required
region str

The region where the job was scheduled

required
session session

A boto3 session to be used for calls to AWS. If one is not provided, a new one will be created.

None

Returns:

Name Type Description
BatchInferer BatchInferer

A configured instance with the job's details

Raises:

Type Description
ValueError

If the job cannot be found or response is invalid

Example

job_arn = "arn:aws:bedrock:region:account:job/xyz123" bi = BatchInferer.recover_details_from_job_arn(job_arn) bi.check_complete() 'Completed'

Source code in src/llmbo/batch_inferer.py
@classmethod
def recover_details_from_job_arn(
    cls, job_arn: str, region: str, session: boto3.Session | None = None
) -> "BatchInferer":
    """Recover a BatchInferer instance from an existing job ARN.

    Used to reconstruct a BatchInferer object when the original Python process
    has terminated but the AWS job is still running or complete.

    Args:
        job_arn (str): The AWS ARN of the existing batch inference job
        region (str): The region where the job was scheduled
        session (boto3.Session, optional): A boto3 session to be used for calls to AWS.
                If one is not provided, a new one will be created.

    Returns:
        BatchInferer: A configured instance with the job's details

    Raises:
        ValueError: If the job cannot be found or response is invalid

    Example:
        >>> job_arn = "arn:aws:bedrock:region:account:job/xyz123"
        >>> bi = BatchInferer.recover_details_from_job_arn(job_arn, region="us-east-1")
        >>> bi.check_complete()
        'Completed'
    """
    cls.logger.info(f"Attempting to Recover BatchInferer from {job_arn}")
    session = session or boto3.Session()
    response = cls.check_for_existing_job(job_arn, region, session)

    try:
        # Extract required parameters from response
        job_name = response["jobName"]
        model_id = response["modelId"]
        bucket_name = response["inputDataConfig"]["s3InputDataConfig"][
            "s3Uri"
        ].split("/")[2]
        role_arn = response["roleArn"]

        # Validate required files exist
        input_file = f"{job_name}.jsonl"
        if not os.path.exists(input_file):
            cls.logger.error(f"Required input file not found: {input_file}")
            raise FileNotFoundError(f"Required input file not found: {input_file}")

        requests = cls._read_jsonl(input_file)

        bi = cls(
            model_name=model_id,
            job_name=job_name,
            region=region,
            bucket_name=bucket_name,
            role_arn=role_arn,
            session=session,
        )
        bi.job_arn = job_arn
        bi.requests = requests
        bi.job_status = response["status"]

        return bi

    except (KeyError, IndexError) as e:
        cls.logger.error(f"Invalid job response format: {str(e)}")
        raise ValueError(f"Invalid job response format: {str(e)}") from e
    except Exception as e:
        cls.logger.error(f"Failed to recover job details: {str(e)}")
        raise RuntimeError(f"Failed to recover job details: {str(e)}") from e
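A recovery sketch; the ARN below is a placeholder, and the original {job_name}.jsonl input file must still exist in the working directory:

job_arn = "arn:aws:bedrock:us-east-1:123456789012:model-invocation-job/abc123"  # placeholder
bi = BatchInferer.recover_details_from_job_arn(job_arn, region="us-east-1")
if bi.check_complete():
    bi.download_results()
    bi.load_results()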

Manifest dataclass

Job manifest details.

Source code in src/llmbo/models.py
@dataclass
class Manifest:
    """Job manifest details."""

    totalRecordCount: int
    processedRecordCount: int
    successRecordCount: int
    errorRecordCount: int
    inputTokenCount: int | None = None
    outputTokenCount: int | None = None
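For illustration, a quick health check computed from a manifest; the counts below are made up:

manifest = Manifest(
    totalRecordCount=100,
    processedRecordCount=100,
    successRecordCount=97,
    errorRecordCount=3,
)
print(f"Success rate: {manifest.successRecordCount / manifest.totalRecordCount:.1%}")  # 97.0%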

ModelAdapterRegistry

Registry for model provider adapters.

This registry maps model name patterns to their corresponding adapter classes. Users can register custom adapters for new model providers or to override existing implementations.

Example

>>> # Register a custom adapter for a new model
>>> ModelAdapterRegistry.register("my-custom-model", MyCustomAdapter)

Source code in src/llmbo/registry.py
class ModelAdapterRegistry:
    """Registry for model provider adapters.

    This registry maps model name patterns to their corresponding adapter classes.
    Users can register custom adapters for new model providers or to override
    existing implementations.

    Example:
        >>> # Register a custom adapter for a new model
        >>> ModelAdapterRegistry.register("my-custom-model", MyCustomAdapter)
    """

    _adapters: list[tuple[Pattern, type[ModelProviderAdapter]]] = []
    logger = logging.getLogger(__name__)

    @classmethod
    def register(cls, pattern: str, adapter_class: type[ModelProviderAdapter]) -> None:
        """Register an adapter class for a specific model pattern.

        Args:
            pattern: Regex pattern to match against model names
            adapter_class: The adapter class to use for matching models

        Raises:
            TypeError: If adapter_class is not a subclass of ModelProviderAdapter
        """
        # Add type validation to ensure adapter_class is a proper ModelProviderAdapter
        if not issubclass(adapter_class, ModelProviderAdapter):
            cls.logger.error(
                f"Adapter class must be a subclass of ModelProviderAdapter, "
                f"got {adapter_class.__name__}"
            )
            raise TypeError(
                f"Adapter class must be a subclass of ModelProviderAdapter, "
                f"got {adapter_class.__name__}"
            )

        compiled_pattern = re.compile(pattern)

        # Check for duplicate pattern and log a warning
        for i, (existing_pattern, _) in enumerate(cls._adapters):
            if existing_pattern.pattern == compiled_pattern.pattern:
                cls.logger.warning(
                    f"Adapter for pattern '{pattern}' is being replaced with {adapter_class.__name__}"
                )
                # Remove the existing adapter with the same pattern
                cls._adapters.pop(i)
                break

        # Add new adapter to the beginning of the list for higher precedence
        cls._adapters.insert(0, (compiled_pattern, adapter_class))
        cls.logger.info(
            f"Registered adapter {adapter_class.__name__} for pattern '{pattern}'"
        )

    @classmethod
    def get_adapter(cls, model_name: str) -> type[ModelProviderAdapter]:
        """Get the appropriate adapter for a model name.

        Args:
            model_name: The model name/ID to find an adapter for

        Returns:
            An adapter class for the given model, or the default adapter if no pattern
            is found
        """
        for pattern, adapter in cls._adapters:
            if pattern.search(model_name):
                return adapter

        cls.logger.warning(
            f"No pattern found for {model_name}, returning default ModelAdapter. "
            "This model is unsupported it may not work as expected.",
        )
        return DefaultAdapter
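A sketch of registering a hypothetical adapter. AcmeAdapter and its method bodies are invented for illustration, mirroring how adapters are invoked elsewhere in this reference (prepare_model_input, validate_result); import paths for ModelProviderAdapter and ModelAdapterRegistry depend on your installation and are omitted:

class AcmeAdapter(ModelProviderAdapter):  # hypothetical provider adapter
    @classmethod
    def prepare_model_input(cls, model_input, output_model=None):
        # Shape model_input for the provider; a real adapter would do more here.
        return model_input

    @classmethod
    def validate_result(cls, model_output, output_model):
        # Parse the provider's output into the requested Pydantic model.
        return output_model(**model_output)

ModelAdapterRegistry.register(r"^acme\.", AcmeAdapter)
assert ModelAdapterRegistry.get_adapter("acme.super-model-v1") is AcmeAdapter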

get_adapter(model_name) classmethod

Get the appropriate adapter for a model name.

Parameters:

Name Type Description Default
model_name str

The model name/ID to find an adapter for

required

Returns:

Type Description
type[ModelProviderAdapter]

An adapter class for the given model, or the default adapter if no pattern

type[ModelProviderAdapter]

is found

Source code in src/llmbo/registry.py
@classmethod
def get_adapter(cls, model_name: str) -> type[ModelProviderAdapter]:
    """Get the appropriate adapter for a model name.

    Args:
        model_name: The model name/ID to find an adapter for

    Returns:
        An adapter class for the given model, or the default adapter if no pattern
        is found
    """
    for pattern, adapter in cls._adapters:
        if pattern.search(model_name):
            return adapter

    cls.logger.warning(
        f"No pattern found for {model_name}, returning default ModelAdapter. "
        "This model is unsupported it may not work as expected.",
    )
    return DefaultAdapter

register(pattern, adapter_class) classmethod

Register an adapter class for a specific model pattern.

Parameters:

Name Type Description Default
pattern str

Regex pattern to match against model names

required
adapter_class type[ModelProviderAdapter]

The adapter class to use for matching models

required

Raises:

Type Description
TypeError

If adapter_class is not a subclass of ModelProviderAdapter

Source code in src/llmbo/registry.py
@classmethod
def register(cls, pattern: str, adapter_class: type[ModelProviderAdapter]) -> None:
    """Register an adapter class for a specific model pattern.

    Args:
        pattern: Regex pattern to match against model names
        adapter_class: The adapter class to use for matching models

    Raises:
        TypeError: If adapter_class is not a subclass of ModelProviderAdapter
    """
    # Add type validation to ensure adapter_class is a proper ModelProviderAdapter
    if not issubclass(adapter_class, ModelProviderAdapter):
        cls.logger.error(
            f"Adapter class must be a subclass of ModelProviderAdapter, "
            f"got {adapter_class.__name__}"
        )
        raise TypeError(
            f"Adapter class must be a subclass of ModelProviderAdapter, "
            f"got {adapter_class.__name__}"
        )

    compiled_pattern = re.compile(pattern)

    # Check for duplicate pattern and log a warning
    for i, (existing_pattern, _) in enumerate(cls._adapters):
        if existing_pattern.pattern == compiled_pattern.pattern:
            cls.logger.warning(
                f"Adapter for pattern '{pattern}' is being replaced with {adapter_class.__name__}"
            )
            # Remove the existing adapter with the same pattern
            cls._adapters.pop(i)
            break

    # Add new adapter to the beginning of the list for higher precedence
    cls._adapters.insert(0, (compiled_pattern, adapter_class))
    cls.logger.info(
        f"Registered adapter {adapter_class.__name__} for pattern '{pattern}'"
    )
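Because each registration is inserted at the front of the list, the most recent matching pattern wins; a sketch using hypothetical adapters AdapterA and AdapterB:

ModelAdapterRegistry.register(r"anthropic", AdapterA)
ModelAdapterRegistry.register(r"claude", AdapterB)
# Both patterns match "anthropic.claude-3", but AdapterB was registered last
# and sits at the front of the list, so it is returned first.
assert ModelAdapterRegistry.get_adapter("anthropic.claude-3") is AdapterB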

ModelInput dataclass

Configuration class for AWS Bedrock model inputs.

This class defines the structure and parameters for model invocation requests following AWS Bedrock's expected format.

See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

Attributes:

Name Type Description
messages List[dict]

List of message objects with role and content

anthropic_version str

Version string for Anthropic models

max_tokens int

Maximum number of tokens in the response

system Optional[str]

System message for the model

stop_sequences Optional[List[str]]

Custom stop sequences

temperature Optional[float]

Sampling temperature

top_p Optional[float]

Nucleus sampling parameter

top_k Optional[int]

Top-k sampling parameter

tools Optional[List[dict]]

Tool definitions for structured outputs

tool_choice Optional[ToolChoice]

Tool selection configuration

Source code in src/llmbo/models.py
@dataclass
class ModelInput:
    """Configuration class for AWS Bedrock model inputs.

    This class defines the structure and parameters for model invocation requests
    following AWS Bedrock's expected format.

    See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

    Attributes:
        messages (List[dict]): List of message objects with role and content
        anthropic_version (str): Version string for Anthropic models
        max_tokens (int): Maximum number of tokens in the response
        system (Optional[str]): System message for the model
        stop_sequences (Optional[List[str]]): Custom stop sequences
        temperature (Optional[float]): Sampling temperature
        top_p (Optional[float]): Nucleus sampling parameter
        top_k (Optional[int]): Top-k sampling parameter
        tools (Optional[List[dict]]): Tool definitions for structured outputs
        tool_choice (Optional[ToolChoice]): Tool selection configuration
    """

    # These are required
    messages: list[dict]
    anthropic_version: str = "bedrock-2023-05-31"
    max_tokens: int = 2000

    system: str | None = None
    stop_sequences: list[str] | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None

    tools: list[dict] | None = None
    tool_choice: ToolChoice | str | None = None

    def to_dict(self):
        """Convert to dict."""
        result = {k: v for k, v in self.__dict__.items() if v is not None}
        if isinstance(self.tool_choice, ToolChoice):
            result["tool_choice"] = self.tool_choice.__dict__
        return result

    def to_json(self):
        """Convert to json string."""
        return json.dumps(self.to_dict())
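A serialization sketch; to_dict() drops None-valued fields, so only populated parameters reach the request payload:

model_input = ModelInput(
    messages=[{"role": "user", "content": "Hello"}],
    system="You are a terse assistant.",
    temperature=0.2,
)
print(model_input.to_json())
# {"messages": [{"role": "user", "content": "Hello"}], "anthropic_version":
#  "bedrock-2023-05-31", "max_tokens": 2000, "system": "You are a terse
#  assistant.", "temperature": 0.2}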

to_dict()

Convert to dict.

Source code in src/llmbo/models.py
def to_dict(self):
    """Convert to dict."""
    result = {k: v for k, v in self.__dict__.items() if v is not None}
    if isinstance(self.tool_choice, ToolChoice):
        result["tool_choice"] = self.tool_choice.__dict__
    return result

to_json()

Convert to json string.

Source code in src/llmbo/models.py
def to_json(self):
    """Convert to json string."""
    return json.dumps(self.to_dict())

StructuredBatchInferer

Bases: BatchInferer

A specialized BatchInferer that enforces structured outputs using Pydantic models.

Inspired by the instructor package (see https://python.useinstructor.com/). This class extends BatchInferer to add schema validation and structured output handling using Pydantic models.

Parameters:

Name Type Description Default
output_model BaseModel

A Pydantic model defining the expected output structure

required
model_name str

The name/ID of the AWS Bedrock model to use

required
bucket_name str

The S3 bucket name for storing input/output data

required
region str

The region to run the batch inference job in.

required
job_name str

A unique name for the batch inference job

required
role_arn str

The AWS IAM role ARN with necessary permissions

required
time_out_duration_hours int

Maximum job runtime in hours. Defaults to 24.

24
session Session

A boto3 session to be used for AWS API calls. If not provided, a new session will be created.

None
Source code in src/llmbo/structured_batch_inferer.py
class StructuredBatchInferer(BatchInferer):
    """A specialized BatchInferer that enforces structured outputs using Pydantic models.

    Inspired by the instructor package, see: https://python.useinstructor.com/
    This class extends BatchInferer to add schema validation and structured output
    handling using Pydantic models.

    Args:
        output_model (BaseModel): A Pydantic model defining the expected output structure
        model_name (str): The name/ID of the AWS Bedrock model to use
        bucket_name (str): The S3 bucket name for storing input/output data
        region (str): The region to run the batch inference job in.
        job_name (str): A unique name for the batch inference job
        role_arn (str): The AWS IAM role ARN with necessary permissions
        time_out_duration_hours (int, optional): Maximum job runtime in hours. Defaults to 24.
        session (boto3.Session, optional): A boto3 session to be used for AWS API calls.
                                           If not provided, a new session will be created.


    """

    logger = logging.getLogger(f"{__name__}.StructuredBatchInferer")

    def __init__(
        self,
        output_model: type[BaseModel],
        model_name: str,  # this should be an enum...
        bucket_name: str,
        region: str,
        job_name: str,
        role_arn: str,
        time_out_duration_hours: int = 24,
        session: boto3.Session | None = None,
    ):
        """Initialize a StructuredBatchInferer for schema-validated batch processing.

        Creates a batch inference manager that enforces structured outputs using
        a Pydantic model schema. Automatically configures the model to use tools
        for enforcing the output structure.

        Args:
            output_model (BaseModel): Pydantic model class defining the expected output structure
            model_name (str): The AWS Bedrock model identifier
            bucket_name (str): Name of the S3 bucket for storing job inputs and outputs
            region (str): Region of the LLM; must match the bucket's region
            job_name (str): Unique identifier for this batch job
            role_arn (str): AWS IAM role ARN with permissions for Bedrock and S3 access
            time_out_duration_hours (int): Number of hours before the job times out
            session (boto3.Session, optional): A boto3 session to be used for AWS API calls. If not provided, a new session will be created.

        Raises:
            KeyError: If AWS_PROFILE environment variable is not set
            ValueError: If the provided role_arn doesn't exist or is invalid

        Example:
            >>> class PersonInfo(BaseModel):
            ...     name: str
            ...     age: int
            ...
            >>> sbi = StructuredBatchInferer(
            ...     output_model=PersonInfo,
            ...     model_name="anthropic.claude-3-haiku-20240307-v1:0",
            ...     bucket_name="my-inference-bucket",
            ...     job_name="structured-batch-2024",
            ...     role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole"
            ... )

        Note:
            - Converts the Pydantic model into a tool definition for the LLM
            - All results will be validated against the provided schema
            - Failed schema validations will raise errors during result processing
            - Inherits all base BatchInferer functionality
        """
        self.output_model = output_model

        self.logger.info(
            f"Initialized StructuredBatchInferer with {output_model.__name__} schema"
        )

        super().__init__(
            model_name=model_name,
            bucket_name=bucket_name,
            region=region,
            job_name=job_name,
            role_arn=role_arn,
            time_out_duration_hours=time_out_duration_hours,
            session=session,
        )

    def prepare_requests(self, inputs: dict[str, ModelInput]):
        """Prepare structured batch inference requests with tool configurations.

        Extends the base preparation by adding tool definitions and tool choice
        parameters to each ModelInput. The tool definition is derived from the
        Pydantic output_model specified during initialization.

        Args:
            inputs (Dict[str, ModelInput]): Dictionary mapping record IDs to their corresponding
                ModelInput configurations. The record IDs will be used to track results.

        Raises:
            ValueError: If len(inputs) < 100, as AWS Bedrock requires a minimum batch size of 100

        Example:
            >>> class PersonInfo(BaseModel):
            ...     name: str
            ...     age: int
            >>> sbi = StructuredBatchInferer(output_model=PersonInfo, ...)
            >>> inputs = {
            ...     "001": ModelInput(
            ...         messages=[{"role": "user", "content": "John is 25 years old"}],
            ...     )
            ... }
            >>> sbi.prepare_requests(inputs)

        Note:
            - Automatically adds the output_model schema as a tool definition
            - Sets tool_choice to force use of the defined schema
            - Original ModelInputs are modified to include tool configurations
        """
        self.logger.info(f"Adding tool {self.output_model.__name__} to model input")
        self._check_input_length(inputs)
        for id, model_input in inputs.items():
            inputs[id] = self.adapter.prepare_model_input(
                model_input, self.output_model
            )

        self.requests = self._to_requests(inputs)

    def load_results(self):
        """Load and validate batch inference results against the output schema.

        Reads the output files downloaded from S3 and validates each result against
        the Pydantic output_model specified during initialization. Populates:
            - self.results: Raw inference results from the output JSONL file
            - self.manifest: Statistics about the job execution
            - self.instances: List of validated Pydantic model instances

        Raises:
            FileExistsError: If either the results or manifest files are not found locally
            ValueError: If any result fails schema validation or tool use validation

        Note:
            - Must call download_results() before calling this method
            - All results must conform to the specified output_model schema
            - Results must show successful tool use
        """
        super().load_results()
        self.instances = [
            {
                "recordId": result["recordId"],
                "outputModel": self.adapter.validate_result(
                    result["modelOutput"], self.output_model
                ),
            }
            if result.get("modelOutput")
            else None
            for result in self.results
        ]

    @classmethod
    def recover_details_from_job_arn(
        cls,
        job_arn: str,
        region: str,
        session: boto3.Session | None = None,
    ) -> "StructuredBatchInferer":
        """Placeholder method for interface consistency.

        This method exists to maintain compatibility with the parent class but
        is not implemented for structured jobs. Use `recover_structured_job`
        instead.

        Raises:
            NotImplementedError: Always raised when called.
        """
        raise NotImplementedError(
            "Cannot recover structured job without output_model. "
            "Use recover_structured_job instead."
        )

    @classmethod
    def recover_structured_job(
        cls,
        job_arn: str,
        region: str,
        output_model: type[BaseModel],
        session: boto3.Session | None = None,
    ) -> "StructuredBatchInferer":
        """Recover a StructuredBatchInferer instance from an existing job ARN.

        Used to reconstruct a StructuredBatchInferer object when the original Python
        process has terminated but the AWS job is still running or complete.

        Args:
            job_arn (str): The AWS ARN of the existing batch inference job
            region (str): The region where the job was scheduled
            output_model (Type[BaseModel]): A pydantic model describing the required output
            session (boto3.Session, optional): A boto3 session to be used for AWS API calls.
                                           If not provided, a new session will be created.

        Returns:
            StructuredBatchInferer: A configured instance with the job's details

        Raises:
            ValueError: If the job cannot be found or response is invalid

        Example:
            >>> job_arn = "arn:aws:bedrock:region:account:job/xyz123"
            >>> region = us-east-1"
            >>> sbi = StructuredBatchInferer.recover_details_from_job_arn(job_arn, region, some_model)
            >>> sbi.check_complete()
            'Completed'
        """
        cls.logger.info(f"Attempting to Recover BatchInferer from {job_arn}")
        session = session or boto3.Session()
        response = cls.check_for_existing_job(job_arn, region, session)

        try:
            # Extract required parameters from response
            job_name = response["jobName"]
            model_id = response["modelId"]
            bucket_name = response["inputDataConfig"]["s3InputDataConfig"][
                "s3Uri"
            ].split("/")[2]
            role_arn = response["roleArn"]

            # Validate required files exist
            input_file = f"{job_name}.jsonl"
            if not os.path.exists(input_file):
                cls.logger.error(f"Required input file not found: {input_file}")
                raise FileNotFoundError(f"Required input file not found: {input_file}")

            requests = cls._read_jsonl(input_file)

            sbi = cls(
                model_name=model_id,
                output_model=output_model,
                job_name=job_name,
                region=region,
                bucket_name=bucket_name,
                role_arn=role_arn,
                session=session,
            )
            sbi.job_arn = job_arn
            sbi.requests = requests
            sbi.job_status = response["status"]

            return sbi

        except (KeyError, IndexError) as e:
            cls.logger.error(f"Invalid job response format: {str(e)}")
            raise ValueError(f"Invalid job response format: {str(e)}") from e
        except Exception as e:
            cls.logger.error(f"Failed to recover job details: {str(e)}")
            raise RuntimeError(f"Failed to recover job details: {str(e)}") from e
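An end-to-end sketch, assuming inputs is a dict[str, ModelInput] with at least 100 records; the bucket, role, and job names are placeholders, and job submission itself is covered by the job-creation methods earlier in this reference:

from pydantic import BaseModel

class PersonInfo(BaseModel):
    name: str
    age: int

sbi = StructuredBatchInferer(
    output_model=PersonInfo,
    model_name="anthropic.claude-3-haiku-20240307-v1:0",
    bucket_name="my-inference-bucket",
    region="us-east-1",
    job_name="structured-batch-2024",
    role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole",
)
sbi.prepare_requests(inputs)  # adds the PersonInfo tool to every request
sbi.push_requests_to_s3()
# ... submit the job, then ...
sbi.poll_progress(poll_interval_seconds=300)
sbi.download_results()
sbi.load_results()  # also populates sbi.instances with validated PersonInfo objects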

__init__(output_model, model_name, bucket_name, region, job_name, role_arn, time_out_duration_hours=24, session=None)

Initialize a StructuredBatchInferer for schema-validated batch processing.

Creates a batch inference manager that enforces structured outputs using a Pydantic model schema. Automatically configures the model to use tools for enforcing the output structure.

Parameters:

Name Type Description Default
output_model BaseModel

Pydantic model class defining the expected output structure

required
model_name str

The AWS Bedrock model identifier

required
bucket_name str

Name of the S3 bucket for storing job inputs and outputs

required
region str

Region of the LLM; must match the bucket's region

required
job_name str

Unique identifier for this batch job

required
role_arn str

AWS IAM role ARN with permissions for Bedrock and S3 access

required
time_out_duration_hours int

Number of hours before the job times out

24
session Session

A boto3 session to be used for AWS API calls. If not provided, a new session will be created.

None

Raises:

Type Description
KeyError

If AWS_PROFILE environment variable is not set

ValueError

If the provided role_arn doesn't exist or is invalid

Example

>>> class PersonInfo(BaseModel):
...     name: str
...     age: int
...
>>> sbi = StructuredBatchInferer(
...     output_model=PersonInfo,
...     model_name="anthropic.claude-3-haiku-20240307-v1:0",
...     bucket_name="my-inference-bucket",
...     job_name="structured-batch-2024",
...     role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole"
... )

Note
  • Converts the Pydantic model into a tool definition for the LLM
  • All results will be validated against the provided schema
  • Failed schema validations will raise errors during result processing
  • Inherits all base BatchInferer functionality
Source code in src/llmbo/structured_batch_inferer.py
def __init__(
    self,
    output_model: type[BaseModel],
    model_name: str,  # this should be an enum...
    bucket_name: str,
    region: str,
    job_name: str,
    role_arn: str,
    time_out_duration_hours: int = 24,
    session: boto3.Session | None = None,
):
    """Initialize a StructuredBatchInferer for schema-validated batch processing.

    Creates a batch inference manager that enforces structured outputs using
    a Pydantic model schema. Automatically configures the model to use tools
    for enforcing the output structure.

    Args:
        output_model (BaseModel): Pydantic model class defining the expected output structure
        model_name (str): The AWS Bedrock model identifier
        bucket_name (str): Name of the S3 bucket for storing job inputs and outputs
        region (str): Region of the LLM; must match the bucket's region
        job_name (str): Unique identifier for this batch job
        role_arn (str): AWS IAM role ARN with permissions for Bedrock and S3 access
        time_out_duration_hours (int): Number of hours before the job times out
        session (boto3.Session, optional): A boto3 session to be used for AWS API calls. If not provided, a new session will be created.

    Raises:
        KeyError: If AWS_PROFILE environment variable is not set
        ValueError: If the provided role_arn doesn't exist or is invalid

    Example:
        >>> class PersonInfo(BaseModel):
        ...     name: str
        ...     age: int
        ...
        >>> sbi = StructuredBatchInferer(
        ...     output_model=PersonInfo,
        ...     model_name="anthropic.claude-3-haiku-20240307-v1:0",
        ...     bucket_name="my-inference-bucket",
        ...     job_name="structured-batch-2024",
        ...     role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole"
        ... )

    Note:
        - Converts the Pydantic model into a tool definition for the LLM
        - All results will be validated against the provided schema
        - Failed schema validations will raise errors during result processing
        - Inherits all base BatchInferer functionality
    """
    self.output_model = output_model

    self.logger.info(
        f"Initialized StructuredBatchInferer with {output_model.__name__} schema"
    )

    super().__init__(
        model_name=model_name,
        bucket_name=bucket_name,
        region=region,
        job_name=job_name,
        role_arn=role_arn,
        time_out_duration_hours=time_out_duration_hours,
        session=session,
    )

load_results()

Load and validate batch inference results against the output schema.

Reads the output files downloaded from S3 and validates each result against the Pydantic output_model specified during initialization. Populates:
  • self.results: Raw inference results from the output JSONL file
  • self.manifest: Statistics about the job execution
  • self.instances: List of validated Pydantic model instances

Raises:

Type Description
FileExistsError

If either the results or manifest files are not found locally

ValueError

If any result fails schema validation or tool use validation

Note
  • Must call download_results() before calling this method
  • All results must conform to the specified output_model schema
  • Results must show successful tool use
Source code in src/llmbo/structured_batch_inferer.py
def load_results(self):
    """Load and validate batch inference results against the output schema.

    Reads the output files downloaded from S3 and validates each result against
    the Pydantic output_model specified during initialization. Populates:
        - self.results: Raw inference results from the output JSONL file
        - self.manifest: Statistics about the job execution
        - self.instances: List of validated Pydantic model instances

    Raises:
        FileExistsError: If either the results or manifest files are not found locally
        ValueError: If any result fails schema validation or tool use validation

    Note:
        - Must call download_results() before calling this method
        - All results must conform to the specified output_model schema
        - Results must show successful tool use
    """
    super().load_results()
    self.instances = [
        {
            "recordId": result["recordId"],
            "outputModel": self.adapter.validate_result(
                result["modelOutput"], self.output_model
            ),
        }
        if result.get("modelOutput")
        else None
        for result in self.results
    ]
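A sketch of consuming the validated instances; entries are None for records that produced no modelOutput:

sbi.load_results()
for item in sbi.instances:
    if item is None:
        continue  # the record had no modelOutput
    person = item["outputModel"]  # a validated PersonInfo instance
    print(item["recordId"], person.name, person.age)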

prepare_requests(inputs)

Prepare structured batch inference requests with tool configurations.

Extends the base preparation by adding tool definitions and tool choice parameters to each ModelInput. The tool definition is derived from the Pydantic output_model specified during initialization.

Parameters:

Name Type Description Default
inputs Dict[str, ModelInput]

Dictionary mapping record IDs to their corresponding ModelInput configurations. The record IDs will be used to track results.

required

Raises:

Type Description
ValueError

If len(inputs) < 100, as AWS Bedrock requires a minimum batch size of 100

Example

>>> class PersonInfo(BaseModel):
...     name: str
...     age: int
>>> sbi = StructuredBatchInferer(output_model=PersonInfo, ...)
>>> inputs = {
...     "001": ModelInput(
...         messages=[{"role": "user", "content": "John is 25 years old"}],
...     )
... }
>>> sbi.prepare_requests(inputs)

Note
  • Automatically adds the output_model schema as a tool definition
  • Sets tool_choice to force use of the defined schema
  • Original ModelInputs are modified to include tool configurations
Source code in src/llmbo/structured_batch_inferer.py
def prepare_requests(self, inputs: dict[str, ModelInput]):
    """Prepare structured batch inference requests with tool configurations.

    Extends the base preparation by adding tool definitions and tool choice
    parameters to each ModelInput. The tool definition is derived from the
    Pydantic output_model specified during initialization.

    Args:
        inputs (Dict[str, ModelInput]): Dictionary mapping record IDs to their corresponding
            ModelInput configurations. The record IDs will be used to track results.

    Raises:
        ValueError: If len(inputs) < 100, as AWS Bedrock requires a minimum batch size of 100

    Example:
        >>> class PersonInfo(BaseModel):
        ...     name: str
        ...     age: int
        >>> sbi = StructuredBatchInferer(output_model=PersonInfo, ...)
        >>> inputs = {
        ...     "001": ModelInput(
        ...         messages=[{"role": "user", "content": "John is 25 years old"}],
        ...     )
        ... }
        >>> sbi.prepare_requests(inputs)

    Note:
        - Automatically adds the output_model schema as a tool definition
        - Sets tool_choice to force use of the defined schema
        - Original ModelInputs are modified to include tool configurations
    """
    self.logger.info(f"Adding tool {self.output_model.__name__} to model input")
    self._check_input_length(inputs)
    for id, model_input in inputs.items():
        inputs[id] = self.adapter.prepare_model_input(
            model_input, self.output_model
        )

    self.requests = self._to_requests(inputs)

recover_details_from_job_arn(job_arn, region, session=None) classmethod

Placeholder method for interface consistency.

This method exists to maintain compatibility with the parent class but is not implemented for structured jobs. Use recover_structured_job instead.

Raises:

Type Description
NotImplementedError

Always raised when called.

Source code in src/llmbo/structured_batch_inferer.py
@classmethod
def recover_details_from_job_arn(
    cls,
    job_arn: str,
    region: str,
    session: boto3.Session | None = None,
) -> "StructuredBatchInferer":
    """Placeholder method for interface consistency.

    This method exists to maintain compatibility with the parent class but
    is not implemented for structured jobs. Use `recover_structured_job`
    instead.

    Raises:
        NotImplementedError: Always raised when called.
    """
    raise NotImplementedError(
        "Cannot recover structured job without output_model. "
        "Use recover_structured_job instead."
    )

recover_structured_job(job_arn, region, output_model, session=None) classmethod

Recover a StructuredBatchInferer instance from an existing job ARN.

Used to reconstruct a StructuredBatchInferer object when the original Python process has terminated but the AWS job is still running or complete.

Parameters:

Name Type Description Default
job_arn str

The AWS ARN of the existing batch inference job

required
region str

The region where the job was scheduled

required
output_model type[BaseModel]

A pydantic model describing the required output

required
session Session

A boto3 session to be used for AWS API calls. If not provided, a new session will be created.

None

Returns:

Name Type Description
StructuredBatchInferer StructuredBatchInferer

A configured instance with the job's details

Raises:

Type Description
ValueError

If the job cannot be found or response is invalid

Example

job_arn = "arn:aws:bedrock:region:account:job/xyz123" region = us-east-1" sbi = StructuredBatchInferer.recover_details_from_job_arn(job_arn, region, some_model) sbi.check_complete() 'Completed'

Source code in src/llmbo/structured_batch_inferer.py
@classmethod
def recover_structured_job(
    cls,
    job_arn: str,
    region: str,
    output_model: type[BaseModel],
    session: boto3.Session | None = None,
) -> "StructuredBatchInferer":
    """Recover a StructuredBatchInferer instance from an existing job ARN.

    Used to reconstruct a StructuredBatchInferer object when the original Python
    process has terminated but the AWS job is still running or complete.

    Args:
        job_arn (str): The AWS ARN of the existing batch inference job
        region (str): The region where the job was scheduled
        output_model (Type[BaseModel]): A pydantic model describing the required output
        session (boto3.Session, optional): A boto3 session to be used for AWS API calls.
                                       If not provided, a new session will be created.

    Returns:
        StructuredBatchInferer: A configured instance with the job's details

    Raises:
        ValueError: If the job cannot be found or response is invalid

    Example:
        >>> job_arn = "arn:aws:bedrock:region:account:job/xyz123"
        >>> region = us-east-1"
        >>> sbi = StructuredBatchInferer.recover_details_from_job_arn(job_arn, region, some_model)
        >>> sbi.check_complete()
        'Completed'
    """
    cls.logger.info(f"Attempting to Recover BatchInferer from {job_arn}")
    session = session or boto3.Session()
    response = cls.check_for_existing_job(job_arn, region, session)

    try:
        # Extract required parameters from response
        job_name = response["jobName"]
        model_id = response["modelId"]
        bucket_name = response["inputDataConfig"]["s3InputDataConfig"][
            "s3Uri"
        ].split("/")[2]
        role_arn = response["roleArn"]

        # Validate required files exist
        input_file = f"{job_name}.jsonl"
        if not os.path.exists(input_file):
            cls.logger.error(f"Required input file not found: {input_file}")
            raise FileNotFoundError(f"Required input file not found: {input_file}")

        requests = cls._read_jsonl(input_file)

        sbi = cls(
            model_name=model_id,
            output_model=output_model,
            job_name=job_name,
            region=region,
            bucket_name=bucket_name,
            role_arn=role_arn,
            session=session,
        )
        sbi.job_arn = job_arn
        sbi.requests = requests
        sbi.job_status = response["status"]

        return sbi

    except (KeyError, IndexError) as e:
        cls.logger.error(f"Invalid job response format: {str(e)}")
        raise ValueError(f"Invalid job response format: {str(e)}") from e
    except Exception as e:
        cls.logger.error(f"Failed to recover job details: {str(e)}")
        raise RuntimeError(f"Failed to recover job details: {str(e)}") from e

ToolChoice dataclass

ToolChoice details.

Source code in src/llmbo/models.py
@dataclass
class ToolChoice:
    """Toolchoice details."""

    type: Literal["any", "tool", "auto"]
    name: str | None = None
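For illustration, forcing a specific tool by name; StructuredBatchInferer normally sets this automatically, so manual use like this is only needed for hand-rolled tool definitions:

choice = ToolChoice(type="tool", name="PersonInfo")  # name must match a tool definition
model_input = ModelInput(
    messages=[{"role": "user", "content": "John is 25 years old"}],
    tool_choice=choice,
)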