Create publication-quality figures and graphs for scientific analysis. Use when creating bar charts, ROC curves, confusion matrices, scatter plots, heatmaps, 3D surface plots, or any data visualization. Supports PDF and PNG output with consistent styling.
Create publication-quality figures for scientific papers and reports.
IMPORTANT: When generating figures using generate_paper_figures.py, the PowerPoint is automatically updated:
# Generate figure (full reprocessing from raw data)
python generate_paper_figures.py --figure 3
# Quick mode: skip reprocessing, use existing processed data + PNGs
python generate_paper_figures.py --figure 2 --quick
# Skip PowerPoint update (figures only)
python generate_paper_figures.py --figure 3 --no-pptx
The script automatically:
- Writes the figure's panel images to Output/PowerPoint_Figures/Fig_X/
- Updates Output/PowerPoint_Figures/Cardiac_RODEO_Tracked.pptx with the new images
Slide 2 (Fig 2) - Runtime Panel Discovery (NOT hardcoded rIds):
Slide 2 uses named groups Panel_2a through Panel_2l in the PPTX.
The script discovers rIds at runtime by parsing group names — this survives
PowerPoint re-saves that renumber rIds. Defined in SLIDE2_PANEL_MAP:
| Panel | Filename in Fig_2/ | Description |
|---|---|---|
| a–f | external | Not replaced (plate photo, microscopy, diagrams) |
| d | Fig_2i.png | SNR Quality Analysis |
| g | Fig_2_Epirubicin_O2.png | Metabolic Dose Dependent Response (averaged) |
| h | Fig_2_Epirubicin_TC50.png | Epirubicin TC50 (32h) |
| i | Fig_2_Epirubicin_O2_heatmap.png | Epirubicin O2 heatmap (LOWESS w=16) |
| j | Fig_2_Mexiletine_Contractility.png | Mexiletine Contractility 2D dose-response |
| k | Fig_2k_Mexiletine_Waveforms.png | Mexiletine heart rate waveforms (48h) |
| l | Fig_2_Mexiletine_Contractility_heatmap.png | Mexiletine Contractility heatmap |
To change which image goes to which panel, edit SLIDE2_PANEL_MAP in
generate_paper_figures.py. The key is the panel letter, the value is the
filename (relative to Output/PowerPoint_Figures/Fig_2/).
Slide 2 is in MANUAL_GROUP_SLIDES — the script does NOT auto-group/label
this slide. Panel groups and labels are managed manually in PowerPoint.
The _discover_slide2_rids() function handles the rId lookup.
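The name-based lookup described above can be sketched with the standard library alone. This is a minimal illustration, not the actual `_discover_slide2_rids()` implementation: the function name `discover_panel_rids`, the `prefix` parameter, and the sample XML are assumptions made for the example. It relies only on the standard Office Open XML layout, where a group's name lives in `p:nvGrpSpPr/p:cNvPr` and a picture's relationship id is the `r:embed` attribute of `a:blip`.

```python
import re
import xml.etree.ElementTree as ET

# Standard Office Open XML namespaces used in slide XML.
NS = {
    "p": "http://schemas.openxmlformats.org/presentationml/2006/main",
    "a": "http://schemas.openxmlformats.org/drawingml/2006/main",
    "r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
}
R_EMBED = "{%s}embed" % NS["r"]


def discover_panel_rids(slide_xml, prefix="Panel_2"):
    """Map panel letter -> rId of the first picture inside each named group.

    Hypothetical helper: groups whose name is not prefix + one letter are ignored.
    """
    root = ET.fromstring(slide_xml.strip())
    found = {}
    for group in root.iter("{%s}grpSp" % NS["p"]):
        props = group.find("p:nvGrpSpPr/p:cNvPr", NS)
        name = props.get("name", "") if props is not None else ""
        match = re.fullmatch(re.escape(prefix) + r"([a-z])", name)
        if not match:
            continue
        blip = group.find(".//p:pic/p:blipFill/a:blip", NS)
        if blip is not None and blip.get(R_EMBED):
            found[match.group(1)] = blip.get(R_EMBED)
    return found


# Made-up slide fragment: two panel groups and one unrelated group.
SAMPLE_SLIDE_XML = """
<p:sld xmlns:p="http://schemas.openxmlformats.org/presentationml/2006/main"
       xmlns:a="http://schemas.openxmlformats.org/drawingml/2006/main"
       xmlns:r="http://schemas.openxmlformats.org/officeDocument/2006/relationships">
  <p:cSld><p:spTree>
    <p:grpSp>
      <p:nvGrpSpPr><p:cNvPr id="10" name="Panel_2g"/></p:nvGrpSpPr>
      <p:pic><p:blipFill><a:blip r:embed="rId7"/></p:blipFill></p:pic>
    </p:grpSp>
    <p:grpSp>
      <p:nvGrpSpPr><p:cNvPr id="11" name="Panel_2h"/></p:nvGrpSpPr>
      <p:pic><p:blipFill><a:blip r:embed="rId3"/></p:blipFill></p:pic>
    </p:grpSp>
    <p:grpSp>
      <p:nvGrpSpPr><p:cNvPr id="12" name="Legend"/></p:nvGrpSpPr>
      <p:pic><p:blipFill><a:blip r:embed="rId9"/></p:blipFill></p:pic>
    </p:grpSp>
  </p:spTree></p:cSld>
</p:sld>
"""

panel_rids = discover_panel_rids(SAMPLE_SLIDE_XML)
print(panel_rids)
```

Because the lookup keys on group names rather than stored rIds, a re-save that renumbers relationships only changes the values in the returned mapping, not which panel each image belongs to.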
Other slides — position-based or explicit rId mapping:
| Figure | PowerPoint Images | Slide |
|---|---|---|
| Fig_3 (a-e) | SLIDE3_RID_MAP (explicit) | 3 |
| Fig_6 (a-h) | position-based | 6 |
| Fig_7 (a-h) | position-based | 7 |
| Fig_8 (a-f) | position-based | 8 |
CRITICAL — Cross-Slide Panel Alignment (Figures 6, 7, 8): Figures 6, 7, and 8 show the same panel layout (a–f) for different drug categories (Arrhythmia+HeartDamage, ADMET comparison, SwissADME comparison). Panels a–f MUST be at identical positions across all three slides so they align when flipping between pages.
Reference positions (from Fig 7, the canonical source):
| Panel | x | y | w | h |
|---|---|---|---|---|
| a | 0.30 | 0.80 | 1.70 | 1.70 |
| b | 2.15 | 0.80 | 1.70 | 1.70 |
| c | 4.00 | 0.80 | 1.70 | 1.70 |
| d | 0.30 | 2.65 | 1.55 | 1.87 |
| e | 2.00 | 2.65 | 2.47 | 1.78 |
| f | 0.30 | 4.67 | 3.40 | 1.70 |
When regenerating or repositioning panels on slides 6–8, always use these exact positions. Extract and verify from the PPTX directly (not just slide_layout.json) since manual edits in PowerPoint can introduce sub-pixel drift.
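One way to run that verification is to compare each slide's extracted positions against the reference table with a small tolerance. The sketch below assumes the positions have already been read from the PPTX and converted to inches; the helper names and the 0.005 inch tolerance are illustrative assumptions, not project settings. The EMU conversion factor (914400 per inch) is the standard Office Open XML unit.

```python
EMU_PER_INCH = 914400  # Office Open XML stores shape offsets and sizes in EMU

# Reference positions in inches (x, y, w, h), copied from the table above.
REFERENCE_POSITIONS = {
    "a": (0.30, 0.80, 1.70, 1.70),
    "b": (2.15, 0.80, 1.70, 1.70),
    "c": (4.00, 0.80, 1.70, 1.70),
    "d": (0.30, 2.65, 1.55, 1.87),
    "e": (2.00, 2.65, 2.47, 1.78),
    "f": (0.30, 4.67, 3.40, 1.70),
}


def emu_to_inches(emu):
    """Convert an EMU value read from the PPTX to inches."""
    return emu / EMU_PER_INCH


def find_misaligned(panels, reference=REFERENCE_POSITIONS, tol=0.005):
    """Return {panel: (actual, expected)} for panels missing or off by more than tol inches."""
    problems = {}
    for letter, expected in reference.items():
        actual = panels.get(letter)
        if actual is None or any(abs(a - e) > tol for a, e in zip(actual, expected)):
            problems[letter] = (actual, expected)
    return problems


# Made-up example: slide 6 matches the reference except panel d, which drifted in y.
slide6 = dict(REFERENCE_POSITIONS)
slide6["d"] = (0.30, 2.66, 1.55, 1.87)
print(find_misaligned(slide6))
```

Running the same check on slides 6, 7, and 8 and requiring an empty result for each is a simple way to confirm the three slides still line up.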
- ROC curves with shaded confidence bands (fill_between)
- Confusion matrices (Blues cmap)
- Heatmaps: correlation (RdBu_r), performance (RdYlGn), SHAP (coolwarm)
Use these colors consistently across ALL figures:
PRIMARY_COLORS = {
'beige': '#E3D5B2', # Warm neutral - backgrounds, secondary elements
'blue': '#6C92ED', # Primary accent - main data series
'pink': '#ECA0C0', # Secondary accent - comparison data
'orange': '#F8B274', # Tertiary accent - highlights
}
# Extended palette (same theme)
EXTENDED_COLORS = {
'dark_blue': '#4A6FBF', # Darker blue for emphasis
'light_pink': '#F5C6D6', # Lighter pink for fills
'coral': '#E89B7A', # Warm coral
'sage': '#A8C4A2', # Muted green
'lavender': '#B8A9D9', # Soft purple
'cream': '#F5EFE0', # Light background
}
colors = {
'CNN (DIQT Transfer)': '#ECA0C0', # Pink
'CNN (5-fold on 25)': '#6C92ED', # Blue
'Organoid (5-fold)': '#F8B274', # Orange
'ADMET-AI': '#6C92ED', # Blue
'SwissADME': '#ECA0C0', # Pink
}
metric_colors = {
'Accuracy': '#6C92ED', # Blue
'F1 Score': '#F8B274', # Orange
'MCC': '#ECA0C0', # Pink
'AUC': '#4A6FBF', # Dark Blue
'Sensitivity': '#E89B7A', # Coral
'Specificity': '#A8C4A2', # Sage
}
Used in the R² comparison bar chart and any figure referencing equations by name. Colors are assigned in rainbow order from best R² (red) to worst (pink).
EQUATION_COLORS = {
'Dual Exponential': '#d62728', # Red
'Hormesis Hill': '#e6550d', # Red-Orange
'PKPD Elimination': '#ff7f0e', # Orange
'Biphasic Response': '#ffc107', # Amber
'Dual Hill Hormesis': '#8bc34a', # Yellow-Green
'Modified Hill': '#2ca02c', # Green
'Adaptive Response': '#00897b', # Teal
'Gaussian Ridge': '#17becf', # Cyan
'Bivariate Gaussian': '#1f77b4', # Blue
'Gaussian-Hill Hybrid': '#5c6bc0', # Indigo
'Recovery Model': '#9467bd', # Purple
'Cumulative Exposure': '#e377c2', # Pink
}
# Internal code names → display names
EQUATION_DISPLAY_NAMES = {
'dual_exponential': 'Dual Exponential',
'bivariate_gaussian': 'Bivariate Gaussian',
'gaussian_hill_hybrid': 'Gaussian-Hill Hybrid',
'modified_hill_hormesis':'Hormesis Hill',
'gaussian_ridge': 'Gaussian Ridge',
'adaptive_response': 'Adaptive Response',
'biphasic_response': 'Biphasic Response',
'cumulative_exposure': 'Cumulative Exposure',
'recovery_model': 'Recovery Model',
'modified_hill_simple': 'Modified Hill',
'pkpd_elimination': 'PKPD Elimination',
'hormesis_v0': 'Dual Hill Hormesis',
}
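Since the color table is keyed by display name while the code uses internal names, a lookup has to go through both dictionaries. The sketch below shows one way to do that using a subset of the entries above; the function `equation_color` and the gray fallback color are assumptions added for illustration.

```python
# Subset of the mappings above, repeated so the example is self-contained.
EQUATION_DISPLAY_NAMES = {
    'dual_exponential': 'Dual Exponential',
    'modified_hill_hormesis': 'Hormesis Hill',
    'hormesis_v0': 'Dual Hill Hormesis',
}
EQUATION_COLORS = {
    'Dual Exponential': '#d62728',
    'Hormesis Hill': '#e6550d',
    'Dual Hill Hormesis': '#8bc34a',
}
FALLBACK_COLOR = '#7f7f7f'  # assumption: neutral gray for equations without an assigned color


def equation_color(internal_name):
    """Return (display_name, hex_color) for an internal equation name."""
    display = EQUATION_DISPLAY_NAMES.get(internal_name, internal_name)
    return display, EQUATION_COLORS.get(display, FALLBACK_COLOR)


for name in ['dual_exponential', 'hormesis_v0', 'unknown_model']:
    print(name, equation_color(name))
```

Going through the display-name table matters for the less obvious pairs, such as `hormesis_v0` mapping to 'Dual Hill Hormesis' and `modified_hill_hormesis` mapping to 'Hormesis Hill'.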
class_colors = {
'Positive': '#ECA0C0', # Pink
'Negative': '#6C92ED', # Blue
'True Positive': '#6C92ED',
'False Positive': '#F8B274',
'True Negative': '#A8C4A2',
'False Negative': '#E89B7A',
}
When the user asks to change figure dimensions (e.g., "make it smaller", "1.7 inches"), they mean the ACTUAL GRAPH/PLOT size, not the overall image canvas.
DO NOT:
- Create a large canvas and shrink the plot inside it (e.g., with ax.set_position)
DO:
- Use figsize=(width, height) to directly control the plot dimensions
- Use bbox_inches='tight' when saving to remove excess whitespace
# CORRECT: figsize controls the actual graph size
fig, ax = plt.subplots(figsize=(1.7, 1.7)) # Graph is 1.7" x 1.7"
plt.savefig('output.png', dpi=600, bbox_inches='tight')
# WRONG: Large canvas with small graph inside
fig, ax = plt.subplots(figsize=(4, 4)) # 4" canvas
ax.set_position([0.3, 0.3, 0.4, 0.4]) # Graph only 1.6" - DON'T DO THIS
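Treat figsize as the graph size: the axes fill the canvas apart from the room needed for axis labels and title. When asked for a "1.7 inch square", use figsize=(1.7, 1.7). Note that bbox_inches='tight' crops or pads the saved image to its contents, so the saved file can differ slightly from figsize.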
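As a quick arithmetic check on a requested size, the pixel dimensions of a saved PNG follow from inches times DPI. The helper below is a hypothetical convenience, not part of the project code.

```python
def expected_pixels(figsize, dpi=600):
    """Pixel size of a figure saved at the given DPI (without bbox_inches='tight')."""
    width_in, height_in = figsize
    return round(width_in * dpi), round(height_in * dpi)


print(expected_pixels((1.7, 1.7)))  # square panel
print(expected_pixels((3.4, 1.7)))  # 2:1 panel
```

Comparing these numbers against the dimensions of the saved file is an easy way to spot a figure that was generated on an oversized canvas.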
# Save PDF for LaTeX/publication
plt.savefig('Output/path/figures/figure_name.pdf', bbox_inches='tight')
# Save PNG at 600 DPI for high-quality viewing
plt.savefig('Output/path/Figure_Name.png', dpi=600, bbox_inches='tight')
plt.close()
# Single panel
fig, ax = plt.subplots(figsize=(8, 6))
# Side-by-side panels
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Three panels horizontal
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
# Grid layout
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Wide scatter plot (per-drug)
fig, ax = plt.subplots(figsize=(16, 6))
# Default square graph (for bar charts, scatter plots, Accuracy vs AUC, etc.)
SQUARE_SIZE = 1.7 # inches - standard square panel
fig, ax = plt.subplots(figsize=(SQUARE_SIZE, SQUARE_SIZE))
# Heatmap (MANDATORY size - 1:2 ratio)
HEATMAP_HEIGHT = 1.7 # inches
HEATMAP_WIDTH = 3.4 # inches (2x height)
fig, ax = plt.subplots(figsize=(HEATMAP_WIDTH, HEATMAP_HEIGHT))
# Accuracy vs AUC scatter (use square size, same as bar charts)
fig, ax = plt.subplots(figsize=(1.7, 1.7)) # 1:1 square
# Or for 3-panel comparison with shared axis:
fig, axes = plt.subplots(1, 3, figsize=(5.1, 1.7), sharey=True) # 3 × 1.7" width
# Bar chart (use square size)
fig, ax = plt.subplots(figsize=(1.7, 1.7)) # Standard square
import numpy as np
import matplotlib.pyplot as plt
models = ['Model A', 'Model B', 'Model C']
metrics = ['Accuracy', 'F1 Score', 'MCC']
# Mean values
data = np.array([
[0.56, 0.72, 0.00], # Model A
[0.68, 0.73, 0.34], # Model B
[0.74, 0.77, 0.46], # Model C
])
# Standard deviations (ALWAYS include error bars)
data_std = np.array([
[0.05, 0.04, 0.00], # Model A
[0.03, 0.05, 0.08], # Model B
[0.04, 0.03, 0.06], # Model C
])
x = np.arange(len(models))
width = 0.25
colors = ['#6C92ED', '#F8B274', '#ECA0C0'] # Use project color palette
# Use 1.7" square for single bar charts (or scale up for grouped)
SQUARE_SIZE = 1.7
fig, ax = plt.subplots(figsize=(SQUARE_SIZE * 2, SQUARE_SIZE)) # 2:1 ratio for grouped
for i, (metric, color) in enumerate(zip(metrics, colors)):
    bars = ax.bar(x + i*width - width, data[:, i], width,
                  yerr=data_std[:, i], capsize=4,  # Error bars with caps
                  label=metric, color=color, edgecolor='black')
    # Add value labels above error bars
    for j, bar in enumerate(bars):
        height = bar.get_height()
        err = data_std[j, i]
        ax.annotate(f'{height:.2f}',
                    xy=(bar.get_x() + bar.get_width()/2, height + err),
                    xytext=(0, 3), textcoords='offset points',
                    ha='center', va='bottom', fontsize=9, fontweight='bold')
ax.set_ylabel('Score', fontsize=9)
ax.set_title('Performance Comparison', fontsize=10, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=8)
ax.tick_params(axis='both', labelsize=8)
ax.legend(loc='upper left', fontsize=8)
ax.set_ylim(0, 1.1) # Extra space for error bars
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
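The `data` and `data_std` arrays in the bar chart example are typically derived from repeated runs such as cross-validation folds. A minimal sketch of that step follows, assuming per-fold scores are available as a plain list; the scores shown are made-up illustrative values and `summarize_folds` is a hypothetical helper.

```python
from statistics import mean, stdev


def summarize_folds(fold_scores):
    """Return (mean, sample standard deviation) for a list of per-fold scores."""
    if len(fold_scores) < 2:
        # A single score has no spread to report.
        return mean(fold_scores), 0.0
    return mean(fold_scores), stdev(fold_scores)


# Illustrative values only, not real results.
accuracy_per_fold = [0.70, 0.75, 0.72, 0.78, 0.75]
acc_mean, acc_std = summarize_folds(accuracy_per_fold)
print(f"{acc_mean:.2f} ± {acc_std:.2f}")
```

Note that `statistics.stdev` is the sample standard deviation (n - 1 denominator), whereas `numpy.std` defaults to the population form (`ddof=0`); whichever is used, it should be applied consistently across figures.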
CRITICAL: ALL ROC curves MUST include shaded confidence bands. ROC curves without shaded uncertainty regions are NOT acceptable for publication figures.
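What the shaded band represents: ±1 standard deviation of the TPR at each FPR point, i.e., the uncertainty from bootstrap resampling.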
import figure_config # FIRST LINE - registers Helvetica
from sklearn.metrics import roc_curve, auc
import numpy as np
import matplotlib.pyplot as plt
def bootstrap_roc_stats(y_true, y_prob, n_boot=300, seed=42):
    """
    Bootstrap ROC statistics for confidence intervals.
    Returns:
        mean_fpr: Common FPR points (0 to 1, 100 points)
        mean_tpr: Mean TPR at each FPR point
        std_tpr: Standard deviation of TPR (for shaded band)
        auc_mean: Mean AUC across bootstrap samples
        auc_std: Standard deviation of AUC
    """
    rng = np.random.default_rng(seed)
    y_true, y_prob = np.asarray(y_true), np.asarray(y_prob)
    n = len(y_true)
    mean_fpr = np.linspace(0, 1, 100)
    tprs, aucs = [], []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        # Skip resamples that contain only one class (ROC is undefined)
        if len(np.unique(y_true[idx])) < 2:
            continue
        fpr, tpr, _ = roc_curve(y_true[idx], y_prob[idx])
        tpr_interp = np.interp(mean_fpr, fpr, tpr)
        tpr_interp[0] = 0.0
        tprs.append(tpr_interp)
        aucs.append(auc(mean_fpr, tpr_interp))
    tprs = np.array(tprs)
    return mean_fpr, tprs.mean(axis=0), tprs.std(axis=0), np.mean(aucs), np.std(aucs)
def plot_roc_with_bands(ax, mean_fpr, mean_tpr, std_tpr, auc_val, auc_std, color, label):
    """
    Plot ROC curve with MANDATORY shaded confidence band.
    The shaded region represents ±1 std of TPR at each FPR point,
    showing the uncertainty from bootstrap resampling.
    """
    # Plot the mean ROC curve
    ax.plot(mean_fpr, mean_tpr, color=color, lw=2,
            label=f'{label} (AUC={auc_val:.2f}±{auc_std:.2f})')
    # MANDATORY: Shaded confidence band between upper and lower bounds
    lower_bound = np.maximum(mean_tpr - std_tpr, 0)  # Clamp to 0 minimum
    upper_bound = np.minimum(mean_tpr + std_tpr, 1)  # Clamp to 1 maximum
    ax.fill_between(
        mean_fpr,          # X values (FPR points)
        lower_bound,       # Lower edge of shaded region
        upper_bound,       # Upper edge of shaded region
        color=color,
        alpha=0.2,         # Semi-transparent shading
        edgecolor='none'   # No edge line on the shaded region
    )
# COMPLETE USAGE EXAMPLE
fig, ax = plt.subplots(figsize=(7, 6))
# Example: Plot ROC for multiple models
# (y_true_* and y_prob_* are placeholders for each model's labels and predicted probabilities)
models_data = [
    ('Model A', y_true_a, y_prob_a, '#6C92ED'),  # Blue
    ('Model B', y_true_b, y_prob_b, '#ECA0C0'),  # Pink
    ('Model C', y_true_c, y_prob_c, '#F8B274'),  # Orange
]
for label, y_true, y_prob, color in models_data:
    # Compute bootstrap statistics
    mean_fpr, mean_tpr, std_tpr, auc_val, auc_std = bootstrap_roc_stats(y_true, y_prob)
    # Plot with MANDATORY shaded band
    plot_roc_with_bands(ax, mean_fpr, mean_tpr, std_tpr, auc_val, auc_std, color, label)
# Random classifier baseline (diagonal)
ax.plot([0, 1], [0, 1], 'k--', lw=1, label='Random (AUC=0.50)')
ax.set_xlabel('False Positive Rate', fontsize=9)
ax.set_ylabel('True Positive Rate', fontsize=9)
ax.set_title('ROC Curves with Confidence Bands', fontsize=10, fontweight='bold')
ax.tick_params(axis='both', labelsize=8)
ax.legend(loc='lower right', fontsize=8)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.grid(alpha=0.3)
plt.tight_layout()
# Save both formats
plt.savefig('roc_curves.pdf', bbox_inches='tight')
plt.savefig('roc_curves.png', dpi=600, bbox_inches='tight')
plt.close()
Visual representation of shaded band:
TPR
1 ┤ ╭───────── Upper bound (mean + std)
│ ╭──┤░░░░░░░░░
│ ╭──┤░░░░░░░░░░░│ ← Shaded region shows uncertainty
│ ╭──┤░░░░░░░░░░░░░│
│ ╭──┤░░░░░░░░░░░░░░░─╯
│ ╭──┤░░░░░░░░░░░░░░░╯ ← Mean ROC curve (solid line)
│ ╭──┤░░░░░░░░░░░░░╯
│╭─┤░░░░░░░░░░░░╯ Lower bound (mean - std)
0 ┼──────────────────────────────
0 1 FPR
CRITICAL Sizing Rules:
- Use a square figure (SQUARE_SIZE = 1.7")
- Use aspect='equal' in imshow() to force square cells
- Set margins left=0.18, right=0.98, top=0.85, bottom=0.18 to maximize plot area
- Do NOT use set_box_aspect(1) - it constrains plot size and wastes space
import figure_config # FIRST LINE - registers Helvetica
import numpy as np
import matplotlib.pyplot as plt
SQUARE_SIZE = 1.7 # Standard square panel size
def plot_confusion_matrix(cm, labels, title, output_path=None):
    """
    Plot confusion matrix with counts and row percentages.
    Uses square cells that fill the figure properly.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0  # Avoid division by zero for empty rows
    row_pct = cm / row_sums
    # Square figure with tight margins to maximize plot area
    fig, ax = plt.subplots(figsize=(SQUARE_SIZE, SQUARE_SIZE))
    fig.subplots_adjust(left=0.18, right=0.98, top=0.85, bottom=0.18)
    # CRITICAL: Use aspect='equal' for square cells (NOT set_box_aspect)
    im = ax.imshow(cm, cmap='Blues', aspect='equal')
    ax.set_title(title, fontsize=10, fontweight='bold')
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, fontsize=8)
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel('Predicted', fontsize=9)
    ax.set_ylabel('Actual', fontsize=9)
    max_val = cm.max() if cm.size else 0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = int(cm[i, j])
            pct = row_pct[i, j] * 100
            # White text on dark cells, black text on light cells
            color = 'white' if max_val > 0 and cm[i, j] > max_val / 2 else 'black'
            ax.text(j, i, f"{count}\n{pct:.1f}%",
                    ha='center', va='center', color=color, fontsize=9)
    if output_path:
        plt.savefig(output_path, dpi=600, bbox_inches='tight')
        plt.close()
    return fig, ax, im
fig, ax = plt.subplots(figsize=(16, 6))
x_pos = np.arange(len(drugs))
# Plot each model with offset
for i, (model, probs, color, marker) in enumerate(model_data):
    offset = (i - 1) * 0.15
    mask = true_labels.astype(bool)
    colors_arr = np.where(mask, color, 'lightgray')
    ax.scatter(x_pos + offset, probs, s=100, c=colors_arr,
               marker=marker, edgecolor='black', linewidth=1.5,
               zorder=3, label=model)
ax.axhline(0.5, color='red', linestyle='--', linewidth=2, alpha=0.7)
ax.set_ylabel('Probability', fontsize=9)
ax.set_xlabel('Drug', fontsize=9)
ax.set_xticks(x_pos)
ax.set_xticklabels(drugs, rotation=45, ha='right', fontsize=8)
ax.tick_params(axis='both', labelsize=8)
ax.set_ylim(-0.05, 1.15)
ax.grid(axis='y', alpha=0.3)
ax.legend(loc='upper right')
CRITICAL Heatmap Rules:
- Cells must be square: use square=True, never rectangles
- Use linewidths=0 (no white space between cells)
import figure_config # FIRST LINE - registers Helvetica
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import re
def clean_concentration_labels(labels):
    """
    Clean concentration labels by removing duplicate decimal suffixes.
    Examples:
        '8.1' → '8' (removes .1 suffix that's just numbering)
        '8.2' → '8' (removes .2 suffix)
        '0.1' → '0.1' (keeps meaningful decimal - it's the actual value)
        '0.1.1' → '0.1' (removes duplicate suffix from 0.1)
    """
    cleaned = []
    for label in labels:
        label_str = str(label)
        # Pattern: number.number.number (like 0.1.1) → keep first two parts
        if re.match(r'^\d+\.\d+\.\d+$', label_str):
            parts = label_str.split('.')
            cleaned.append(f"{parts[0]}.{parts[1]}")
            continue
        # Pattern: nonzero integer + single-digit suffix (like 8.1, 8.2) → remove suffix.
        # The integer part must not be 0; otherwise a real value such as '0.1' would become '0'.
        suffix_match = re.match(r'^([1-9]\d*)\.[1-9]$', label_str)
        if suffix_match:
            cleaned.append(suffix_match.group(1))
        else:
            cleaned.append(label_str)
    return cleaned
# Example: Drug response heatmap
# CRITICAL: Rows = concentrations (Y-axis), Columns = time (X-axis)
n_concentrations = 8
n_timepoints = 40
data_matrix = pd.DataFrame(
np.random.randn(n_concentrations, n_timepoints),
index=[f'{c}' for c in [0, 0.1, 1, 2, 4, 8, 16, 32]], # Concentration labels
columns=[f'{t}' for t in range(0, n_timepoints * 2, 2)] # Time labels (hours)
)
# Clean concentration labels (Y-axis)
y_labels = clean_concentration_labels(data_matrix.index.tolist())
x_labels = data_matrix.columns.tolist() # Time labels (X-axis)
# Setup custom colormap with project colors
from matplotlib.colors import LinearSegmentedColormap
# Project heatmap colors (MANDATORY)
HEATMAP_BLUE = '#123BFF' # Low values
HEATMAP_RED = '#FF2908' # High values
# Create custom diverging colormap: Blue -> White -> Red
cmap = LinearSegmentedColormap.from_list(
'cardiac_rodeo',
[HEATMAP_BLUE, 'white', HEATMAP_RED]
)
cmap.set_bad('white') # NaN values display as white
# MANDATORY Figure size for heatmaps (1:2 ratio)
HEATMAP_HEIGHT = 1.7 # inches
HEATMAP_WIDTH = 3.4 # inches (2x height)
fig, ax = plt.subplots(figsize=(HEATMAP_WIDTH, HEATMAP_HEIGHT))
# Create heatmap with SQUARE cells and NO borders/gaps
sns.heatmap(
data_matrix,
annot=False, # No cell annotations
cmap=cmap,
cbar_kws={'label': 'Response', 'shrink': 0.8},
xticklabels=x_labels, # Time labels (X-axis)
yticklabels=y_labels, # Concentration labels (Y-axis)
square=True, # CRITICAL: Square cells, not rectangles
mask=False,
linewidths=0 # CRITICAL: No borders/gaps between cells
)
# Customize tick labels - FIXED SIZES
ax.set_xticklabels(x_labels, rotation=0, ha='center', fontsize=8)
ax.set_yticklabels(y_labels, fontsize=8, rotation=0)
# CRITICAL: X = Time, Y = Concentration - FIXED FONT SIZES
ax.set_xlabel('Time (Hours)', fontsize=9)
ax.set_ylabel('Concentration (mM)', fontsize=9)
ax.set_title('Drug Response Heatmap', fontsize=10, fontweight='bold')
plt.tight_layout()
Heatmap Axis Orientation (MANDATORY):
X-axis: Time (Hours) →
┌─────────────────────────┐
Y │ 0 2 4 6 ... 78 │
a ├─────────────────────────┤
x │ 32 │ ■ │ ■ │ ■ │ │ ■ │
i │ 16 │ ■ │ ■ │ ■ │ │ ■ │
s │ 8 │ ■ │ ■ │ ■ │ │ ■ │
: │ 4 │ ■ │ ■ │ ■ │ │ ■ │
C │ 2 │ ■ │ ■ │ ■ │ │ ■ │
o │ 1 │ ■ │ ■ │ ■ │ │ ■ │
n │0.1 │ ■ │ ■ │ ■ │ │ ■ │
c │ 0 │ ■ │ ■ │ ■ │ │ ■ │
└─────────────────────────┘
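The diagram places the highest concentration in the top row, while `sns.heatmap` draws rows in index order with the first row at the top. If the matrix index is ascending, the rows therefore need reordering before plotting. A minimal sketch of that ordering step, assuming the row labels are numeric strings; the helper name is hypothetical.

```python
def order_rows_high_to_low(labels):
    """Sort numeric string labels so the highest concentration comes first (top row)."""
    return sorted(labels, key=float, reverse=True)


row_labels = ['0', '0.1', '1', '2', '4', '8', '16', '32']
ordered = order_rows_high_to_low(row_labels)
print(ordered)
```

With a pandas DataFrame, the reordered labels could then be applied with `data_matrix.loc[ordered]` before passing the matrix to the heatmap call.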
Heatmap Colormap Reference:
| Data Type | Colormap | Setup | Notes |
|---|---|---|---|
| Drug response (DEFAULT) | Custom cardiac_rodeo | See code above | #123BFF (blue) → White → #FF2908 (red) |
| Correlation | Custom or RdBu_r | center=0, vmin=-1, vmax=1 | Symmetric around zero |
| Performance (R², accuracy) | RdYlGn | center=0 | Red (bad) → Yellow → Green (good) |
| Confusion Matrix | Blues | N/A | One-sided, counts only |
| SHAP values | coolwarm | N/A | Blue (neg) → White → Red (pos) |
MANDATORY Heatmap Colors:
HEATMAP_BLUE = '#123BFF' # For low values
HEATMAP_RED = '#FF2908' # For high values
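To make the blue → white → red mapping concrete, the sketch below linearly interpolates between the two mandatory colors through white for a normalized value in [0, 1]. It is a plain-Python approximation for illustration; matplotlib's `LinearSegmentedColormap` quantizes into a fixed number of levels, so its output can differ slightly. The function names are assumptions.

```python
HEATMAP_BLUE = '#123BFF'  # Low values
HEATMAP_RED = '#FF2908'   # High values


def hex_to_rgb(hex_color):
    """Convert '#RRGGBB' to an (r, g, b) tuple of ints."""
    h = hex_color.lstrip('#')
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))


def diverging_color(t, low=HEATMAP_BLUE, mid='#FFFFFF', high=HEATMAP_RED):
    """Approximate colormap lookup: t=0 -> low, t=0.5 -> mid, t=1 -> high."""
    t = min(max(t, 0.0), 1.0)  # Clamp to the valid range
    if t <= 0.5:
        start, end, frac = low, mid, t / 0.5
    else:
        start, end, frac = mid, high, (t - 0.5) / 0.5
    rgb = [round(s + (e - s) * frac)
           for s, e in zip(hex_to_rgb(start), hex_to_rgb(end))]
    return '#{:02X}{:02X}{:02X}'.format(*rgb)


for value in (0.0, 0.5, 1.0):
    print(value, diverging_color(value))
```

The midpoint maps to white, which is why data plotted with this colormap should be normalized so that the neutral response sits at the center of the color range.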
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Example data
drugs = ['Amiodarone', 'Bortezomib', 'Chlorpromazine', 'Doxorubicin', 'Erlotinib',
'Ibuprofen', 'Nifedipine', 'Sotalol', 'Vincristine', 'Vioxx']
predictions = np.array([85, 72, 45, 90, 30, 25, 55, 78, 40, 65]) # Probability %
actual_positive = np.array([True, True, False, True, False, False, True, True, False, True])
# Sort drugs: by classification (positive first), then alphabetically
df = pd.DataFrame({'Drug': drugs, 'Pred': predictions, 'Positive': actual_positive})
df = df.sort_values(['Positive', 'Drug'], ascending=[False, True])
drugs_sorted = df['Drug'].tolist()
preds_sorted = df['Pred'].values
status_sorted = df['Positive'].values
# Colors: Green = positive, Red = negative
pos_color = '#2ca02c' # Green
neg_color = '#d62728' # Red
threshold_color = '#1f77b4' # Blue for threshold line
# Create horizontal scatter plot (drugs on Y-axis)
fig, ax = plt.subplots(figsize=(10, max(6, len(drugs) * 0.4)))
positions = np.arange(len(drugs_sorted))
point_colors = [pos_color if s else neg_color for s in status_sorted]
# Scatter: X = probability, Y = drug position
ax.scatter(preds_sorted, positions, c=point_colors, s=100,
edgecolors='black', linewidth=0.5, zorder=3)
# Compute threshold: max(negative samples) + margin, rounded to nearest 5
margin_pp = 2.0
neg_preds = preds_sorted[~status_sorted]
if len(neg_preds) > 0:
    threshold = float(np.max(neg_preds)) + margin_pp