omniverse1 commited on
Commit
9b29694
·
verified ·
1 Parent(s): 21bb68b

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +12 -5
model_handler.py CHANGED
@@ -15,7 +15,6 @@ class ModelHandler:
15
  try:
16
  print(f"Loading {self.model_name} on {self.device}...")
17
 
18
- # Pemuatan otomatis oleh pipeline (sudah terbukti berhasil di langkah sebelumnya)
19
  self.pipeline = BaseChronosPipeline.from_pretrained(
20
  self.model_name,
21
  device_map=self.device,
@@ -52,12 +51,20 @@ class ModelHandler:
52
  predictions_samples = self.pipeline.predict(
53
  data['original'],
54
  prediction_length=horizon,
55
- # KOREKSI: Mengganti 'num_samples' menjadi 'n_samples'
56
- n_samples=20
57
  )
58
 
59
- # Mengambil nilai rata-rata (mean) dari semua sampel
60
- mean_predictions = np.mean(predictions_samples, axis=0)
 
 
 
 
 
 
 
 
61
 
62
  return mean_predictions
63
 
 
15
  try:
16
  print(f"Loading {self.model_name} on {self.device}...")
17
 
 
18
  self.pipeline = BaseChronosPipeline.from_pretrained(
19
  self.model_name,
20
  device_map=self.device,
 
51
  predictions_samples = self.pipeline.predict(
52
  data['original'],
53
  prediction_length=horizon,
54
+ # FIX UTAMA: Menghapus 'n_samples' untuk menghindari error.
55
+ # Model akan kembali ke single trajectory prediction (deterministic)
56
  )
57
 
58
+ # Karena sampling dihilangkan, asumsikan output adalah single trajectory (1D atau 2D dengan dimensi pertama 1)
59
+ if predictions_samples.ndim > 1 and predictions_samples.shape[0] > 1:
60
+ # Jika model tetap mengembalikan multiple samples (probabilistic)
61
+ mean_predictions = np.mean(predictions_samples, axis=0)
62
+ elif predictions_samples.ndim > 1 and predictions_samples.shape[0] == 1:
63
+ # Jika hanya satu trajectory
64
+ mean_predictions = predictions_samples[0]
65
+ else:
66
+ # Jika sudah 1D
67
+ mean_predictions = predictions_samples
68
 
69
  return mean_predictions
70