src/sn_comparator.[ch]: Add a user data member.
[sort-networks.git] / src / sn_stage.c
index 7319094..ae61578 100644 (file)
@@ -152,6 +152,7 @@ int sn_stage_comparator_remove (sn_stage_t *s, int c_num)
 sn_stage_t *sn_stage_clone (const sn_stage_t *s)
 {
   sn_stage_t *s_copy;
+  int i;
 
   s_copy = sn_stage_create (s->depth);
   if (s_copy == NULL)
@@ -165,8 +166,13 @@ sn_stage_t *sn_stage_clone (const sn_stage_t *s)
     return (NULL);
   }
 
-  memcpy (s_copy->comparators, s->comparators,
-      s->comparators_num * sizeof (sn_comparator_t));
+  for (i = 0; i < s->comparators_num; i++)
+  {
+    SN_COMP_MIN (s_copy->comparators + i) = SN_COMP_MIN (s->comparators + i);
+    SN_COMP_MAX (s_copy->comparators + i) = SN_COMP_MAX (s->comparators + i);
+    SN_COMP_USER_DATA (s_copy->comparators + i) = NULL;
+    SN_COMP_FREE_FUNC (s_copy->comparators + i) = NULL;
+  }
   s_copy->comparators_num = s->comparators_num;
 
   return (s_copy);
@@ -360,6 +366,47 @@ int sn_stage_cut_at (sn_stage_t *s, int input, enum sn_network_cut_dir_e dir)
   return (new_position);
 } /* int sn_stage_cut_at */
 
+int sn_stage_cut (sn_stage_t *s, int *mask, /* {{{ */
+    sn_stage_t **prev)
+{
+  int i;
+
+  if ((s == NULL) || (mask == NULL) || (prev == NULL))
+    return (EINVAL);
+
+  for (i = 0; i < s->comparators_num; i++)
+  {
+    sn_comparator_t *c = s->comparators + i;
+    int left = SN_COMP_LEFT (c);
+    int right = SN_COMP_RIGHT (c);
+
+    if ((mask[left] == 0)
+        && (mask[right] == 0))
+      continue;
+
+    /* Check if we need to update the cut position */
+    if ((mask[left] != mask[right])
+        && ((mask[left] > 0) || (mask[right] < 0)))
+    {
+      int tmp;
+      int j;
+
+      tmp = mask[right];
+      mask[right] = mask[left];
+      mask[left] = tmp;
+
+      for (j = s->depth - 1; j >= 0; j--)
+        sn_stage_swap (prev[j],
+            left, right);
+    }
+
+    sn_stage_comparator_remove (s, i);
+    i--;
+  } /* for (i = 0; i < s->comparators_num; i++) */
+
+  return (0);
+} /* }}} int sn_stage_cut */
+
 int sn_stage_remove_input (sn_stage_t *s, int input)
 {
   int i;